diff --git a/.github/ISSUE_TEMPLATE/bug-performance-issue.md b/.github/ISSUE_TEMPLATE/bug-performance-issue.md new file mode 100644 index 0000000000000000000000000000000000000000..c590a962cb7b7cb4760be953a8d64fd96a0381a9 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-performance-issue.md @@ -0,0 +1,35 @@ +-------------------------------------------------------------------------------- + +name: Bug/Performance Issue about: Use this template for reporting a bug or a +performance issue. + +-------------------------------------------------------------------------------- + +Please make sure that this is a bug. As per our +[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md) +we only address code/doc bugs, performance issues, feature requests and +build/installation issues on GitHub. tag:bug_template + +**System information** - Have I written custom code (as opposed to using a stock +example script provided in TensorFlow): - OS Platform and Distribution (e.g., +Linux Ubuntu 16.04): - Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if +the issue happens on mobile device: - TensorFlow installed from (source or +binary): - TensorFlow version (use command below): - Python version: - Bazel +version (if compiling from source): - GCC/Compiler version (if compiling from +source): - CUDA/cuDNN version: - GPU model and memory: + +You can collect some of this information using our environment capture +[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh) +You can also obtain the TensorFlow version with python -c "import tensorflow as +tf; print(tf.GIT_VERSION, tf.VERSION)" + +**Describe the current behavior** + +**Describe the expected behavior** + +**Code to reproduce the issue** Provide a reproducible test case that is the +bare minimum necessary to generate the problem. + +**Other info / logs** Include any logs or source code that would be helpful to +diagnose the problem. 
If including tracebacks, please include the full +traceback. Large logs and files should be attached. diff --git a/.github/ISSUE_TEMPLATE/build-installation-issue.md b/.github/ISSUE_TEMPLATE/build-installation-issue.md new file mode 100644 index 0000000000000000000000000000000000000000..fac9ddfbd70611f6448142d25cde6935a8e63b4f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/build-installation-issue.md @@ -0,0 +1,27 @@ +-------------------------------------------------------------------------------- + +name: Build/Installation Issue about: Use this template for build/installation +issues + +-------------------------------------------------------------------------------- + +Please make sure that this is a build/installation issue. As per our +[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md) +we only address code/doc bugs, performance issues, feature requests and +build/installation issues on GitHub. tag:build_template + +**System information** - OS Platform and Distribution (e.g., Linux Ubuntu +16.04): - Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue +happens on mobile device: - TensorFlow installed from (source or binary): - +TensorFlow version: - Python version: - Installed using virtualenv? pip? +conda?: - Bazel version (if compiling from source): - GCC/Compiler version (if +compiling from source): - CUDA/cuDNN version: - GPU model and memory: + +**Describe the problem** + +**Provide the exact sequence of commands / steps that you executed before +running into the problem** + +**Any other info / logs** Include any logs or source code that would be helpful +to diagnose the problem. If including tracebacks, please include the full +traceback. Large logs and files should be attached. 
diff --git a/.github/ISSUE_TEMPLATE/documentation-issue.md b/.github/ISSUE_TEMPLATE/documentation-issue.md new file mode 100644 index 0000000000000000000000000000000000000000..610da5dd467b7bc80d4eab20310e15ad6fbcf13b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/documentation-issue.md @@ -0,0 +1,18 @@ +-------------------------------------------------------------------------------- + +name: Documentation Issue about: Use this template for documentation-related +issues + +-------------------------------------------------------------------------------- + +Please make sure that this is a documentation issue. As per our +[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md) +we only address code/doc bugs, performance issues, feature requests and +build/installation issues on GitHub. tag:doc_template + +**System information** - TensorFlow version: - Doc Link: + +**Describe the documentation issue** + +**We welcome contributions by users. Will you be able to submit a PR to +fix the doc issue?** diff --git a/.github/ISSUE_TEMPLATE/feature-request.md b/.github/ISSUE_TEMPLATE/feature-request.md new file mode 100644 index 0000000000000000000000000000000000000000..9f06e1759ff5ff57a72408fd5a74e20b5e7e6ef0 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.md @@ -0,0 +1,21 @@ +-------------------------------------------------------------------------------- + +name: Feature Request about: Use this template for raising a feature request + +-------------------------------------------------------------------------------- + +Please make sure that this is a feature request. As per our +[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md) +we only address code/doc bugs, performance issues, feature requests and +build/installation issues on GitHub. 
tag:feature_template + +**System information** - TensorFlow version (you are using): - Are you willing +to contribute it (Yes/No): + +**Describe the feature and the current behavior/state.** + +**Will this change the current API? How?** + +**Who will benefit from this feature?** + +**Any other info.** diff --git a/.github/ISSUE_TEMPLATE/other-issues.md b/.github/ISSUE_TEMPLATE/other-issues.md new file mode 100644 index 0000000000000000000000000000000000000000..b53bdb3c1689398ef7624ae858dba61b0f5e0a54 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/other-issues.md @@ -0,0 +1,21 @@ +-------------------------------------------------------------------------------- + +name: Other Issues about: Use this template for any other non-support-related +issues + +-------------------------------------------------------------------------------- + +This template is for miscellaneous issues not covered by the other issue +categories. + +For questions on how to work with TensorFlow, or support for problems that are +not verified bugs in TensorFlow, please go to +[Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow). + +If you are reporting a vulnerability, please use the +[dedicated reporting process](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md). + +For high-level discussions about TensorFlow, please post to +discuss@tensorflow.org. For questions about the development or internal workings +of TensorFlow, or if you would like to know how to contribute to TensorFlow, +please post to developers@tensorflow.org. 
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 5fff9d05a1c589636bc9c711e6eb7cc4aba86b2f..20601eaf611d98f78382a7d260629e72e24a07c0 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -7,19 +7,22 @@ In the interest of fostering an open and welcoming environment, we as contributo Examples of behavior that contributes to creating a positive environment include: -* Using welcoming and inclusive language -* Being respectful of differing viewpoints and experiences -* Gracefully accepting constructive criticism -* Focusing on what is best for the community -* Showing empathy towards other community members +* Using welcoming and inclusive language. +* Being respectful of differing viewpoints and experiences. +* Gracefully accepting constructive criticism. +* Focusing on what is best for the community. +* Showing empathy towards other community members. Examples of unacceptable behavior by participants include: -* The use of sexualized language or imagery and unwelcome sexual attention or advances -* Trolling, insulting/derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or electronic address, without explicit permission -* Conduct which could reasonably be considered inappropriate for the forum in which it occurs. +* The use of sexualized language or imagery and unwelcome sexual attention or + advances. +* Trolling, insulting/derogatory comments, and personal or political attacks. +* Public or private harassment. +* Publishing others' private information, such as a physical or electronic + address, without explicit permission. +* Conduct which could reasonably be considered inappropriate for the forum in + which it occurs. All TensorFlow forums and spaces are meant for professional interactions, and any behavior which could reasonably be considered inappropriate in a professional setting is unacceptable. 
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f598999f351c10f8bd01dfbd3ad8897f19d570e8..4a296f265f7b9521c46d350cec26ff199f43eb6c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,8 +31,12 @@ Follow either of the two links above to access the appropriate CLA and instructi If you have improvements to TensorFlow, send us your pull requests! For those just getting started, Github has a [howto](https://help.github.com/articles/using-pull-requests/). -TensorFlow team members will be assigned to review your pull requests. Once the pull requests are approved and pass continuous integration checks, we will merge the pull requests. -For some pull requests, we will apply the patch for each pull request to our internal version control system first, and export the change out as a new commit later, at which point the original pull request will be closed. The commits in the pull request will be squashed into a single commit with the pull request creator as the author. These pull requests will be labeled as pending merge internally. +TensorFlow team members will be assigned to review your pull requests. Once the +pull requests are approved and pass continuous integration checks, a TensorFlow +team member will apply the `ready to pull` label to your change. This means we are +working on getting your pull request submitted to our internal repository. After +the change has been submitted internally, your pull request will be merged +automatically on GitHub. If you want to contribute but you're not sure where to start, take a look at the [issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome). diff --git a/ISSUES.md b/ISSUES.md new file mode 100644 index 0000000000000000000000000000000000000000..2b330e8e0a8a3f64753cfb7a2e2362222439312d --- /dev/null +++ b/ISSUES.md @@ -0,0 +1,9 @@ +If you open a GitHub Issue, here is our policy: 1. 
It must be a bug/performance +issue or a feature request or a build issue or a documentation issue (for small +doc fixes please send a PR instead). 2. Make sure the Issue Template is filled +out. 3. The issue should be related to the repo it is created in. + +**Here's why we have this policy:** We want to focus on the work that benefits +the whole community, e.g., fixing bugs and adding features. Individual support +should be sought on Stack Overflow or other non-GitHub channels. This helps us +address bugs and feature requests in a timely manner. diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index 52faed9297cfcaf8c93bb9c79686c9258a53c560..b3d84ad8c948df9459a8e8afb029785d6f6ad335 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -29,9 +29,11 @@ You can collect some of this information using our environment capture script: https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh -You can obtain the TensorFlow version with +You can obtain the TensorFlow version with: +```bash python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)" +``` ### Describe the problem Describe the problem clearly here. Be sure to convey here why it's a bug in TensorFlow or a feature request. diff --git a/README.md b/README.md index 57efb876c9afaf9fe76c4ced4e6a1572e9241edf..0c8d4d4ef08ec2598bf55ec1f168323f6ad755e1 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,21 @@ subscribing to [announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce). 
## Installation -*See [Installing TensorFlow](https://www.tensorflow.org/install) for instructions on how to install our release binaries or how to build from source.* + +To install the current release for CPU-only: + +``` +pip install tensorflow +``` + +Use the GPU package for CUDA-enabled GPU cards: + +``` +pip install tensorflow-gpu +``` + +*See [Installing TensorFlow](https://www.tensorflow.org/install) for detailed +instructions, and how to build from source.* People who are a little more adventurous can also try our nightly binaries: @@ -65,9 +79,10 @@ guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's uphold this code.** **We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for -tracking requests and bugs. So please see -[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions -and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** +tracking requests and bugs, so please see +[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) +for general questions and discussion, and please direct specific questions to +[Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** The TensorFlow project strives to abide by generally accepted best practices in open-source software development: @@ -93,14 +108,14 @@ The TensorFlow project strives to abide by generally accepted best practices in ### Community Supported Builds -| Build Type | Status | Artifacts | -| --- | --- | --- | -| **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | -| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | -| **IBM ppc64le GPU** | [![Build 
Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA | -| **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | -| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.10.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp27-cp27mu-linux_x86_64.whl)
[1.10.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl)
[1.10.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp36-cp36m-linux_x86_64.whl) | - +Build Type | Status | Artifacts +---------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- +**IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA +**IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA +**IBM ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) +**IBM ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) +**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) +**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.10.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp27-cp27mu-linux_x86_64.whl)
[1.10.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl)
[1.10.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp36-cp36m-linux_x86_64.whl) ## For more information * [TensorFlow Website](https://www.tensorflow.org) diff --git a/RELEASE.md b/RELEASE.md index 20e1d9217b7684e696d0abf427eef9ab9548d1b7..2b00d06580d925a4afed5753afb8f51f0ebac99f 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,74 @@ +# Release 1.12.0 + +## Major Features and Improvements + +* Keras models can now be directly exported to the SavedModel + format (`tf.contrib.saved_model.save_keras_model()`) and used with TensorFlow + Serving. +* Keras models now support evaluating with a `tf.data.Dataset`. +* TensorFlow binaries are built with XLA support linked in by default. + +## Bug Fixes and Other Changes + +* tf.data: + * tf.data users can now represent, get, and set options of TensorFlow + input pipelines using `tf.data.Options()`, `tf.data.Dataset.options()`, + and `tf.data.Dataset.with_options()` respectively. + * New `tf.data.Dataset.reduce()` API allows users to reduce a finite + dataset to a single element using a user-provided reduce function. + * New `tf.data.Dataset.window()` API allows users to create finite windows + of the input dataset; when combined with the `tf.data.Dataset.reduce()` API, + this allows users to implement customized batching. + * All C++ code moves to the `tensorflow::data` namespace. + * Add support for `num_parallel_calls` to `tf.data.Dataset.interleave`. +* `tf.contrib`: + * Remove `tf.contrib.linalg`. `tf.linalg` should be used instead. + * Replace any calls to `tf.contrib.get_signature_def_by_key(metagraph_def, + signature_def_key)` with + `meta_graph_def.signature_def[signature_def_key]`. Catching a ValueError + exception thrown by `tf.contrib.get_signature_def_by_key` should be + replaced by catching a KeyError exception. +* `tf.contrib.data` + * Deprecated and replaced by `tf.data.experimental`. 
+* Other: + * Instead of jemalloc, revert to using system malloc since it + simplifies build and has comparable performance. + * Remove integer types from `tf.nn.softplus` and `tf.nn.softsign` OpDefs. + This is a bugfix; these ops were never meant to support integers. + * Allow subslicing Tensors with a single dimension. + * Add option to calculate string length in Unicode characters. + * Add functionality to SubSlice a tensor. + * Add searchsorted (i.e. lower/upper_bound) op. + * Add model explainability to Boosted Trees. + * Support negative positions for tf.substr. + * Fix a bug in the bijector_impl where + _reduce_jacobian_det_over_event did not handle scalar ILDJ + implementations properly. + * In tf eager execution, allow re-entering a GradientTape context. + * Add tf_api_version flag. If the `--define=tf_api_version=2` flag is passed in, + then bazel will build TensorFlow API version 2.0. Note that TensorFlow + 2.0 is under active development and has no guarantees at this point. + * Add additional compression options to TFRecordWriter. + * Performance improvements for regex full match operations. + * Replace tf.GraphKeys.VARIABLES with `tf.GraphKeys.GLOBAL_VARIABLES`. + * Remove unused dynamic learning rate support. 
+ +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +(David) Siu-Kei Muk, Ag Ramesh, Anton Dmitriev, Artem Sobolev, Avijit-Nervana, +Bairen Yi, Bruno Goncalves, By Shen, candy.dc, Cheng Chen, Clayne Robison, +coder3101, Dao Zhang, Elms, Fei Hu, feiquan, Geoffrey Irving, Guozhong Zhuang, +hellcom, Hoeseong Kim, imsheridan, Jason Furmanek, Jason Zaman, Jenny Sahng, +jiefangxuanyan, Johannes Bannhofer, Jonathan Homer, Koan-Sin Tan, kouml, Loo +Rong Jie, Lukas Geiger, manipopopo, Ming Li, Moritz KröGer, Naurril, Niranjan +Hasabnis, Pan Daoxin, Peng Yu, pengwa, rasmi, Roger Xin, Roland Fernandez, Sami +Kama, Samuel Matzek, Sangjung Woo, Sergei Lebedev, Sergii Khomenko, shaohua, +Shaohua Zhang, Shujian2015, Sunitha Kambhampati, tomguluson92, ViníCius Camargo, +wangsiyu, weidankong, Wen-Heng (Jack) Chung, William D. Irons, Xin Jin, Yan +Facai (颜发才), Yanbo Liang, Yash Katariya, Yong Tang, 在原佐为 + # Release 1.11.0 ## Major Features and Improvements @@ -20,51 +91,84 @@ ## Bug Fixes and Other Changes -* C++: - * Changed the signature of SessionFactory::NewSession so that it can return a meaningful error message on failure. -* tf.data: - * Remove `num_parallel_parser_calls` argument from `tf.contrib.data.make_csv_dataset()`. [tf.data] Remove `num_parallel_parser_calls` argument from `tf.contrib.data.make_csv_dataset()`. - * `tf.data.Dataset.list_files()` raises an exception at initialization time if the argument matches no files. - * Renamed BigTable class to BigtableTable for clarity - * Document use of the Cloud Bigtable API - * Adding `tf.contrib.data.reduce_dataset` which can be used to reduce a dataset to a single element. - * Generalization of `tf.contrib.data.sliding_window_batch`. -* INC: - * Runtime improvements to triangular solve. -* `tf.contrib`: - * Add an `implementation` argument to `tf.keras.layers.LocallyConnected2D` and `tf.keras.layers.LocallyConnected1D`. 
The new mode (`implementation=2`) performs forward pass as a single dense matrix multiplication, allowing dramatic speedups in certain scenarios (but worse performance in others - see docstring). The option also allows to use `padding=same`. - * Add documentation clarifying the differences between tf.fill and tf.constant. - * Add experimental IndexedDatasets. - * Add selective registration target using the lite proto runtime. - * Add simple Tensor and DataType classes to TensorFlow Lite Java - * Add support for bitcasting to/from uint32 and uint64. - * Added a subclass of Estimator that can be created from a SavedModel (SavedModelEstimator). - * Adds leaf index modes as an argument. - * Allow a different output shape from the input in tf.contrib.image.transform. - * Change the state_size order of the StackedRNNCell to be natural order. To keep the existing behavior, user can add reverse_state_order=True when constructing the StackedRNNCells. - * Deprecate self.test_session() in favor of self.session() or self.cached_session(). - * Directly import tensor.proto.h (the transitive import will be removed from tensor.h soon) - * Estimator.train() now supports tf.contrib.summary.\* summaries out of the box; each call to .train() will now create a separate tfevents file rather than re-using a shared one. - * Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term should not end up in the accumulator. - * Fix toco compilation/execution on Windows - * GoogleZoneProvider class added to detect which Google Cloud Engine zone tensorflow is running in. - * It is now safe to call any of the C API's TF_Delete\* functions on nullptr - * Log some errors on Android to logcat - * Match FakeQuant numerics in TFLite to improve accuracy of TFLite quantized inference models. - * Optional bucket location check for the GCS Filesystem. - * Performance enhancements for StringSplitOp & StringSplitV2Op. - * Performance improvements for regex replace operations. 
- * TFRecordWriter now raises an error if .write() fails. - * TPU: More helpful error messages in TPUClusterResolvers. - * The legacy_init_op argument to SavedModelBuilder methods for adding MetaGraphs has been deprecated. Please use the equivalent main_op argument instead. As part of this, we now explicitly check for a single main_op or legacy_init_op at the time of SavedModel building, whereas the check on main_op was previously only done at load time. - * The protocol used for Estimator training is now configurable in RunConfig. - * Triangular solve performance improvements. - * Unify RNN cell interface between TF and Keras. Add new get_initial_state() to Keras and TF RNN cell, which will use to replace the existing zero_state() method. - * Update initialization of variables in Keras. - * Updates to "constrained_optimization" in tensorflow/contrib. - * boosted trees: adding pruning mode - * tf.train.Checkpoint does not delete old checkpoints by default. - * tfdbg: Limit the total disk space occupied by dumped tensor data to 100 GBytes. Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow adjustment of this upper limit. +* C++: + * Changed the signature of SessionFactory::NewSession so that it can + return a meaningful error message on failure. +* tf.data: + * Remove `num_parallel_parser_calls` argument from + `tf.contrib.data.make_csv_dataset()`. [tf.data] Remove + `num_parallel_parser_calls` argument from + `tf.contrib.data.make_csv_dataset()`. + * `tf.data.Dataset.list_files()` raises an exception at initialization + time if the argument matches no files. + * Renamed BigTable class to BigtableTable for clarity + * Document use of the Cloud Bigtable API + * Add `tf.contrib.data.reduce_dataset` which can be used to reduce a + dataset to a single element. + * Generalization of `tf.contrib.data.sliding_window_batch`. +* INC: + * Runtime improvements to triangular solve. 
+* `tf.contrib`: + * Add an `implementation` argument to `tf.keras.layers.LocallyConnected2D` + and `tf.keras.layers.LocallyConnected1D`. The new mode + (`implementation=2`) performs the forward pass as a single dense matrix + multiplication, allowing dramatic speedups in certain scenarios (but + worse performance in others - see docstring). The option also allows the + use of `padding=same`. + * Add documentation clarifying the differences between tf.fill and + tf.constant. + * Add experimental IndexedDatasets. + * Add selective registration target using the lite proto runtime. + * Add simple Tensor and DataType classes to TensorFlow Lite Java + * Add support for bitcasting to/from uint32 and uint64. + * Added a subclass of Estimator that can be created from a SavedModel + (SavedModelEstimator). + * Adds leaf index modes as an argument. + * Allow a different output shape from the input in + tf.contrib.image.transform. + * Change the state_size order of the StackedRNNCell to be natural order. + To keep the existing behavior, users can add reverse_state_order=True + when constructing the StackedRNNCells. + * Deprecate self.test_session() in favor of self.session() or + self.cached_session(). + * Directly import tensor.proto.h (the transitive import will be removed + from tensor.h soon) + * Estimator.train() now supports tf.contrib.summary.\* summaries out of + the box; each call to .train() will now create a separate tfevents file + rather than re-using a shared one. + * Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term + should not end up in the accumulator. + * Fix toco compilation/execution on Windows + * GoogleZoneProvider class added to detect which Google Cloud Engine zone + tensorflow is running in. + * It is now safe to call any of the C API's TF_Delete\* functions on + nullptr + * Log some errors on Android to logcat + * Match FakeQuant numerics in TFLite to improve accuracy of TFLite + quantized inference models. 
+ * Optional bucket location check for the GCS Filesystem. + * Performance enhancements for StringSplitOp & StringSplitV2Op. + * Performance improvements for regex replace operations. + * TFRecordWriter now raises an error if .write() fails. + * TPU: More helpful error messages in TPUClusterResolvers. + * The legacy_init_op argument to SavedModelBuilder methods for adding + MetaGraphs has been deprecated. Please use the equivalent main_op + argument instead. As part of this, we now explicitly check for a single + main_op or legacy_init_op at the time of SavedModel building, whereas + the check on main_op was previously only done at load time. + * The protocol used for Estimator training is now configurable in + RunConfig. + * Triangular solve performance improvements. + * Unify RNN cell interface between TF and Keras. Add new + get_initial_state() to Keras and TF RNN cell, which will be used to replace + the existing zero_state() method. + * Update initialization of variables in Keras. + * Updates to "constrained_optimization" in tensorflow/contrib. + * boosted trees: adding pruning mode + * tf.train.Checkpoint does not delete old checkpoints by default. + * tfdbg: Limit the total disk space occupied by dumped tensor data to 100 + GBytes. Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow + adjustment of this upper limit. 
## Thanks to our Contributors diff --git a/configure.py b/configure.py index 0a3b9a7894d8ec7c381954e7dbb5f1bf42233157..b564da27227ec07713f91e925ea292b35f0f02df 100644 --- a/configure.py +++ b/configure.py @@ -35,7 +35,6 @@ except ImportError: _DEFAULT_CUDA_VERSION = '9.0' _DEFAULT_CUDNN_VERSION = '7' -_DEFAULT_NCCL_VERSION = '2.2' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' @@ -48,10 +47,13 @@ _SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 -_TF_WORKSPACE_ROOT = os.path.abspath(os.path.dirname(__file__)) _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' -_TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) -_TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE') +_TF_WORKSPACE_ROOT = '' +_TF_BAZELRC = '' + +NCCL_LIB_PATHS = [ + 'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', '' +] if platform.machine() == 'ppc64le': _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/' @@ -224,7 +226,7 @@ def setup_python(environ_cp): python_lib_path = default_python_lib_path environ_cp['PYTHON_LIB_PATH'] = python_lib_path - python_major_version = get_python_major_version(python_bin_path) + _ = get_python_major_version(python_bin_path) # Convert python path to Windows style before writing into bazel.rc if is_windows() or is_cygwin(): @@ -243,10 +245,10 @@ def setup_python(environ_cp): f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) -def reset_tf_configure_bazelrc(workspace_path): +def reset_tf_configure_bazelrc(): """Reset file that contains customized config settings.""" open(_TF_BAZELRC, 'w').close() - bazelrc_path = os.path.join(workspace_path, '.bazelrc') + bazelrc_path = os.path.join(_TF_WORKSPACE_ROOT, '.bazelrc') data = [] if os.path.exists(bazelrc_path): @@ -257,12 +259,7 @@ def reset_tf_configure_bazelrc(workspace_path): if _TF_BAZELRC_FILENAME in l: continue f.write('%s\n' % l) - if 
is_windows(): - tf_bazelrc_path = _TF_BAZELRC.replace('\\', '/') - else: - tf_bazelrc_path = _TF_BAZELRC - f.write('import %s\n' % tf_bazelrc_path) - + f.write('import %%workspace%%/%s\n' % _TF_BAZELRC_FILENAME) def cleanup_makefile(): """Delete any leftover BUILD files from the Makefile build. @@ -386,7 +383,9 @@ def set_build_var(environ_cp, var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default))) environ_cp[var_name] = var if var == '1': - write_to_bazelrc('build --define %s=true' % option_name) + write_to_bazelrc( + 'build:%s --define %s=true' % (bazel_config_name, option_name)) + write_to_bazelrc('build --config=%s' % bazel_config_name) elif bazel_config_name is not None: # TODO(mikecase): Migrate all users of configure.py to use --config Bazel # options and not to set build configs through environment variables. @@ -498,7 +497,7 @@ def set_cc_opt_flags(environ_cp): elif is_windows(): default_cc_opt_flags = '/arch:AVX' else: - default_cc_opt_flags = '-march=native' + default_cc_opt_flags = '-march=native -Wno-sign-compare' question = ('Please specify optimization flags to use during compilation when' ' bazel option "--config=opt" is specified [Default is %s]: ' ) % default_cc_opt_flags @@ -885,7 +884,7 @@ def set_tf_cudnn_version(environ_cp): """Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION.""" ask_cudnn_version = ( 'Please specify the cuDNN version you want to use. 
' - '[Leave empty to default to cuDNN %s.0]: ') % _DEFAULT_CUDNN_VERSION + '[Leave empty to default to cuDNN %s]: ') % _DEFAULT_CUDNN_VERSION for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): tf_cudnn_version = get_from_env_or_user_or_default( @@ -1042,7 +1041,7 @@ def set_tf_tensorrt_install_path(environ_cp): for lib_file in possible_files: if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver): matches = nvinfer_pattern.search(lib_file) - if len(matches.groups()) == 0: + if not matches.groups(): continue ver_str = matches.group(1) ver = convert_version_to_int(ver_str) if len(ver_str) else 0 @@ -1098,7 +1097,7 @@ def set_tf_tensorrt_install_path(environ_cp): def set_tf_nccl_install_path(environ_cp): - """Set NCCL_INSTALL_PATH and TF_NCCL_VERSION. + """Set NCCL_INSTALL_PATH, NCCL_HDR_PATH and TF_NCCL_VERSION. Args: environ_cp: copy of the os.environ. @@ -1111,59 +1110,119 @@ def set_tf_nccl_install_path(environ_cp): raise ValueError('Currently NCCL is only supported on Linux platforms.') ask_nccl_version = ( - 'Please specify the NCCL version you want to use. If NCCL %s is not ' - 'installed, then you can use version 1.3 that can be fetched ' - 'automatically but it may have worse performance with multiple GPUs. ' - '[Default is %s]: ') % (_DEFAULT_NCCL_VERSION, _DEFAULT_NCCL_VERSION) + 'Please specify the locally installed NCCL version you want to use. ' + '[Default is to use https://github.com/nvidia/nccl]: ') for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): tf_nccl_version = get_from_env_or_user_or_default( - environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, _DEFAULT_NCCL_VERSION) - tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1) + environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, '') + + if not tf_nccl_version: + break # No need to get install path, building the open source code. - if tf_nccl_version == '1': - break # No need to get install path, NCCL 1 is a GitHub repo. 
+ tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1) - # TODO(csigg): Look with ldconfig first if we can find the library in paths + # Look with ldconfig first if we can find the library in paths # like /usr/lib/x86_64-linux-gnu and the header file in the corresponding # include directory. This is where the NCCL .deb packages install them. - # Then ask the user if we should use that. Instead of a single - # NCCL_INSTALL_PATH, pass separate NCCL_LIB_PATH and NCCL_HDR_PATH to - # nccl_configure.bzl - default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH') - ask_nccl_path = (r'Please specify the location where NCCL %s library is ' - 'installed. Refer to README.md for more details. [Default ' - 'is %s]:') % (tf_nccl_version, default_nccl_path) - nccl_install_path = get_from_env_or_user_or_default( - environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path) - - # Result returned from "read" will be used unexpanded. That make "~" - # unusable. Going through one more level of expansion to handle that. - nccl_install_path = os.path.realpath(os.path.expanduser(nccl_install_path)) - if is_windows() or is_cygwin(): - nccl_install_path = cygpath(nccl_install_path) - if is_windows(): - nccl_lib_path = 'lib/x64/nccl.lib' - elif is_linux(): - nccl_lib_path = 'lib/libnccl.so.%s' % tf_nccl_version - elif is_macos(): - nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version - - nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) - nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h') - if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): - # Set NCCL_INSTALL_PATH - environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path - write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) - break - - # Reset and Retry - print('Invalid path to NCCL %s toolkit, %s or %s not found. Please use the ' + # First check to see if NCCL is in the ldconfig. + # If its found, use that location. 
+ if is_linux(): + ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' + nccl2_path_from_ldconfig = run_shell([ldconfig_bin, '-p']) + nccl2_path_from_ldconfig = re.search('.*libnccl.so .* => (.*)', + nccl2_path_from_ldconfig) + if nccl2_path_from_ldconfig: + nccl2_path_from_ldconfig = nccl2_path_from_ldconfig.group(1) + if os.path.exists('%s.%s' % (nccl2_path_from_ldconfig, tf_nccl_version)): + nccl_install_path = os.path.dirname(nccl2_path_from_ldconfig) + print('NCCL libraries found in ' + nccl2_path_from_ldconfig) + + # Check if this is the main system lib location + if re.search('.*linux-gnu', nccl_install_path): + trunc_nccl_install_path = '/usr' + print('This looks like a system path.') + else: + trunc_nccl_install_path = nccl_install_path + '/..' + + # Look for header + nccl_hdr_path = trunc_nccl_install_path + '/include' + print('Assuming NCCL header path is ' + nccl_hdr_path) + if os.path.exists(nccl_hdr_path + '/nccl.h'): + # Set NCCL_INSTALL_PATH + environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path + write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) + + # Set NCCL_HDR_PATH + environ_cp['NCCL_HDR_PATH'] = nccl_hdr_path + write_action_env_to_bazelrc('NCCL_HDR_PATH', nccl_hdr_path) + break + else: + print( + 'The header for NCCL2 cannot be found. Please install the libnccl-dev package.' + ) + else: + print('NCCL2 is listed by ldconfig but the library is not found. ' + 'Your ldconfig is out of date. Please run sudo ldconfig.') + else: + # NCCL is not found in ldconfig. Ask the user for the location. + default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH') + ask_nccl_path = ( + r'Please specify the location where NCCL %s library is ' + 'installed. Refer to README.md for more details. [Default ' + 'is %s]:') % (tf_nccl_version, default_nccl_path) + nccl_install_path = get_from_env_or_user_or_default( + environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path) + + # Result returned from "read" will be used unexpanded. 
That make "~" + # unusable. Going through one more level of expansion to handle that. + nccl_install_path = os.path.realpath( + os.path.expanduser(nccl_install_path)) + if is_windows() or is_cygwin(): + nccl_install_path = cygpath(nccl_install_path) + + if is_windows(): + nccl_lib_path = 'lib/x64/nccl.lib' + elif is_linux(): + nccl_lib_filename = 'libnccl.so.%s' % tf_nccl_version + nccl_lpath = '%s/lib/%s' % (nccl_install_path, nccl_lib_filename) + if not os.path.exists(nccl_lpath): + for relative_path in NCCL_LIB_PATHS: + path = '%s/%s%s' % (nccl_install_path, relative_path, + nccl_lib_filename) + if os.path.exists(path): + print('NCCL found at ' + path) + nccl_lib_path = path + break + else: + nccl_lib_path = nccl_lpath + elif is_macos(): + nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version + + nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) + nccl_hdr_path = os.path.join( + os.path.dirname(nccl_lib_path), '../include/nccl.h') + print('Assuming NCCL header path is ' + nccl_hdr_path) + if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): + # Set NCCL_INSTALL_PATH + environ_cp['NCCL_INSTALL_PATH'] = os.path.dirname(nccl_lib_path) + write_action_env_to_bazelrc('NCCL_INSTALL_PATH', + os.path.dirname(nccl_lib_path)) + + # Set NCCL_HDR_PATH + environ_cp['NCCL_HDR_PATH'] = os.path.dirname(nccl_hdr_path) + write_action_env_to_bazelrc('NCCL_HDR_PATH', + os.path.dirname(nccl_hdr_path)) + break + + # Reset and Retry + print( + 'Invalid path to NCCL %s toolkit, %s or %s not found. Please use the ' 'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path, nccl_hdr_path)) - environ_cp['TF_NCCL_VERSION'] = '' + environ_cp['TF_NCCL_VERSION'] = '' else: raise UserInputError('Invalid TF_NCCL setting was provided %d ' 'times in a row. Assuming to be a scripting mistake.' 
% @@ -1173,7 +1232,6 @@ def set_tf_nccl_install_path(environ_cp): environ_cp['TF_NCCL_VERSION'] = tf_nccl_version write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version) - def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. @@ -1410,7 +1468,7 @@ def set_other_mpi_vars(environ_cp): def set_system_libs_flag(environ_cp): syslibs = environ_cp.get('TF_SYSTEM_LIBS', '') - if syslibs and syslibs != '': + if syslibs: if ',' in syslibs: syslibs = ','.join(sorted(syslibs.split(','))) else: @@ -1440,14 +1498,6 @@ def set_windows_build_flags(environ_cp): # TODO(pcloudy): Remove this flag when upgrading Bazel to 0.16.0 # Short object file path will be enabled by default. write_to_bazelrc('build --experimental_shortened_obj_file_path=true') - # When building zip file for some py_binary and py_test targets, don't - # include its dependencies. This is for: - # 1. Running python tests against the system installed TF pip package. - # 2. Avoiding redundant files in - # //tensorflow/tools/pip_package:simple_console_windows, - # which is a py_binary used during creating TF pip package. - # See https://github.com/tensorflow/tensorflow/issues/22390 - write_to_bazelrc('build --define=no_tensorflow_py_deps=true') if get_var( environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline', @@ -1469,26 +1519,31 @@ def config_info_line(name, help_text): def main(): + global _TF_WORKSPACE_ROOT + global _TF_BAZELRC + parser = argparse.ArgumentParser() parser.add_argument( '--workspace', type=str, - default=_TF_WORKSPACE_ROOT, + default=os.path.abspath(os.path.dirname(__file__)), help='The absolute path to your active Bazel workspace.') args = parser.parse_args() + _TF_WORKSPACE_ROOT = args.workspace + _TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) + # Make a copy of os.environ to be clear when functions and getting and setting # environment variables. 
environ_cp = dict(os.environ) check_bazel_version('0.15.0') - reset_tf_configure_bazelrc(args.workspace) + reset_tf_configure_bazelrc() cleanup_makefile() setup_python(environ_cp) if is_windows(): - environ_cp['TF_NEED_JEMALLOC'] = '0' environ_cp['TF_NEED_OPENCL_SYCL'] = '0' environ_cp['TF_NEED_COMPUTECPP'] = '0' environ_cp['TF_NEED_OPENCL'] = '0' @@ -1497,14 +1552,11 @@ def main(): # TODO(ibiryukov): Investigate using clang as a cpu or cuda compiler on # Windows. environ_cp['TF_DOWNLOAD_CLANG'] = '0' - environ_cp['TF_ENABLE_XLA'] = '0' environ_cp['TF_NEED_MPI'] = '0' environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0' if is_macos(): - environ_cp['TF_NEED_JEMALLOC'] = '0' environ_cp['TF_NEED_TENSORRT'] = '0' - environ_cp['TF_ENABLE_XLA'] = '0' # The numpy package on ppc64le uses OpenBLAS which has multi-threading # issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at @@ -1513,9 +1565,9 @@ def main(): if is_ppc64le(): write_action_env_to_bazelrc('OMP_NUM_THREADS', 1) + xla_enabled_by_default = is_linux() set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', - True, 'xla') - + xla_enabled_by_default, 'xla') set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False) if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': @@ -1607,19 +1659,23 @@ def main(): create_android_ndk_rule(environ_cp) create_android_sdk_rule(environ_cp) - # On Windows, we don't have MKL support and the build is always monolithic. - # So no need to print the following message. - # TODO(pcloudy): remove the following if check when they make sense on Windows - if not is_windows(): - print('Preconfigured Bazel build configs. You can use any of the below by ' - 'adding "--config=<>" to your build command. 
See tools/bazel.rc for ' - 'more details.') - config_info_line('mkl', 'Build with MKL support.') - config_info_line('monolithic', 'Config for mostly static monolithic build.') - config_info_line('gdr', 'Build with GDR support.') - config_info_line('verbs', 'Build with libverbs support.') - config_info_line('ngraph', 'Build with Intel nGraph support.') + print('Preconfigured Bazel build configs. You can use any of the below by ' + 'adding "--config=<>" to your build command. See .bazelrc for more ' + 'details.') + config_info_line('mkl', 'Build with MKL support.') + config_info_line('monolithic', 'Config for mostly static monolithic build.') + config_info_line('gdr', 'Build with GDR support.') + config_info_line('verbs', 'Build with libverbs support.') + config_info_line('ngraph', 'Build with Intel nGraph support.') + + print('Preconfigured Bazel build configs to DISABLE default on features:') + config_info_line('noaws', 'Disable AWS S3 filesystem support.') + config_info_line('nogcp', 'Disable GCP support.') + config_info_line('nohdfs', 'Disable HDFS support.') + config_info_line('noignite', 'Disable Apacha Ignite support.') + config_info_line('nokafka', 'Disable Apache Kafka support.') if __name__ == '__main__': main() + diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 5f73da68a2adb489d003ba121de92ae78eb1d98b..77e3baaff198b402dc04daa1b11e4007b9906b23 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -203,27 +203,46 @@ config_setting( visibility = ["//visibility:public"], ) -# TODO(jhseu): Enable on other platforms other than Linux. config_setting( - name = "with_jemalloc_linux_x86_64", - define_values = {"with_jemalloc": "true"}, - values = {"cpu": "k8"}, + name = "with_default_optimizations", + define_values = {"with_default_optimizations": "true"}, visibility = ["//visibility:public"], ) +# Features that are default ON are handled differently below. 
+# config_setting( - name = "with_jemalloc_linux_ppc64le", - define_values = {"with_jemalloc": "true"}, - values = {"cpu": "ppc"}, + name = "no_aws_support", + define_values = {"no_aws_support": "false"}, visibility = ["//visibility:public"], ) config_setting( - name = "with_default_optimizations", - define_values = {"with_default_optimizations": "true"}, + name = "no_gcp_support", + define_values = {"no_gcp_support": "false"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "no_hdfs_support", + define_values = {"no_hdfs_support": "false"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "no_ignite_support", + define_values = {"no_ignite_support": "false"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "no_kafka_support", + define_values = {"no_kafka_support": "false"}, visibility = ["//visibility:public"], ) +# Crosses between platforms and file system libraries not supported on those +# platforms due to limitations in nested select() statements. 
config_setting( name = "with_cuda_support_windows_override", define_values = {"using_cuda_nvcc": "true"}, @@ -259,30 +278,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "with_jemalloc_linux_x86_64_dynamic", - define_values = { - "with_jemalloc": "true", - "framework_shared_object": "true", - }, - values = { - "cpu": "k8", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_jemalloc_linux_ppc64le_dynamic", - define_values = { - "with_jemalloc": "true", - "framework_shared_object": "true", - }, - values = { - "cpu": "ppc", - }, - visibility = ["//visibility:public"], -) - config_setting( name = "using_cuda_clang", define_values = { diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 17e2e292eb19029d279bc12a8328edadf96f1bb8..56f5e6767ac68b1008c786e3b5a47b9b173ab9cb 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -6,11 +6,12 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", "tf_cc_test", - "tf_cuda_cc_test", "tf_copts", "tf_cuda_library", "tf_custom_op_library", + "tf_kernel_library", ) +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") # ----------------------------------------------------------------------------- # Public targets @@ -197,9 +198,9 @@ tf_cuda_cc_test( size = "small", srcs = ["c_api_test.cc"], data = [ - ":test_op.so", "//tensorflow/cc/saved_model:saved_model_half_plus_two", ], + kernels = [":test_op_kernel"], linkopts = select({ "//tensorflow:darwin": ["-headerpad_max_install_names"], "//conditions:default": [], @@ -285,6 +286,16 @@ tf_custom_op_library( srcs = ["test_op.cc"], ) +tf_kernel_library( + name = "test_op_kernel", + srcs = ["test_op.cc"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + # ----------------------------------------------------------------------------- # Python API target diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 
79811ceae57e0bddeb2a6f32bad7003e14e23422..1726db12fa62c5a3665de9fc306da38c1b7f0f9c 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2770,6 +2770,9 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, } string name_str(name, name_len); const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str); + if (api_def == nullptr) { + return nullptr; + } TF_Buffer* ret = TF_NewBuffer(); status->status = MessageToBuffer(*api_def, ret); diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 03516c39dc970aa23967107d3a0446da94669465..c4746b4990bc3bf80b749428f803056e552421c3 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" @@ -195,12 +196,31 @@ TEST(CAPI, LibraryLoadFunctions) { TF_DeleteStatus(status); ASSERT_EQ(TF_OK, code) << status_msg; - // Test op list. 
- TF_Buffer op_list_buf = TF_GetOpList(lib); - tensorflow::OpList op_list; - EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length)); - ASSERT_EQ(op_list.op_size(), 1); - EXPECT_EQ("TestCApi", op_list.op(0).name()); + { + TF_Buffer* op_list_buffer = TF_GetAllOpList(); + tensorflow::OpList op_list; + op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length); + ASSERT_GE(op_list.op_size(), 1); + typedef tensorflow::protobuf::RepeatedPtrField OpDefs; + const OpDefs& ops = op_list.op(); + bool found = std::find_if(ops.begin(), ops.end(), + [](const tensorflow::OpDef& op_def) { + return op_def.name() == "TestCApi"; + }) != ops.end(); + EXPECT_TRUE(found); + TF_DeleteBuffer(op_list_buffer); + } + +#if !defined(TENSORFLOW_NO_SHARED_OBJECTS) + { + // Test op list. + TF_Buffer op_list_buf = TF_GetOpList(lib); + tensorflow::OpList op_list; + EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length)); + ASSERT_EQ(op_list.op_size(), 1); + EXPECT_EQ("TestCApi", op_list.op(0).name()); + } +#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS) TF_DeleteLibraryHandle(lib); } @@ -2335,9 +2355,9 @@ TEST(TestApiDef, TestCreateApiDef) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); - TF_Buffer op_list_buf = TF_GetOpList(lib); + TF_Buffer* op_list_buf = TF_GetAllOpList(); status = TF_NewStatus(); - auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status); + auto* api_def_map = TF_NewApiDefMap(op_list_buf, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); @@ -2355,6 +2375,7 @@ TEST(TestApiDef, TestCreateApiDef) { TF_DeleteBuffer(api_def_buf); TF_DeleteApiDefMap(api_def_map); + TF_DeleteBuffer(op_list_buf); TF_DeleteLibraryHandle(lib); } @@ -2369,9 +2390,9 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); - TF_Buffer op_list_buf = TF_GetOpList(lib); + TF_Buffer* op_list_buf = 
TF_GetAllOpList(); status = TF_NewStatus(); - auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status); + auto* api_def_map = TF_NewApiDefMap(op_list_buf, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); @@ -2400,6 +2421,7 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) { TF_DeleteBuffer(api_def_buf); TF_DeleteApiDefMap(api_def_map); + TF_DeleteBuffer(op_list_buf); TF_DeleteLibraryHandle(lib); } diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 0bf3d9542b72ecff916986ab809e8793b796d14c..3554ec0bf3202b54bfc38d67e51b89df19832302 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -578,6 +578,14 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, op->operation.MutableAttrs()->Set(attr_name, attr_value); } +void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, + const char* data, size_t length) { + tensorflow::AttrValue attr_value; + tensorflow::NameAttrList* func = attr_value.mutable_func(); + func->set_name(data, length); + op->operation.MutableAttrs()->Set(attr_name, attr_value); +} + void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor, TF_Status* status) { tensorflow::Tensor t; diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 6323f8a053197bb7069acf2d43214fb78c36f436..b2454d872207e26feb3764671474a5d87c01f84d 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -313,6 +313,9 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, const TFE_Op* value); +TF_CAPI_EXPORT void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, + const char* data, size_t length); + TF_CAPI_EXPORT extern void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor, diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index 
5607c9dcb0bbec72b2f86def3dd4e6590d73197b..008f088c2dcdd7d9114103516a4702e47a55c6de 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -99,8 +99,6 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { TFE_OpAddInput(op, b, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); - TFE_OpSetAttrBool(op, "transpose_a", 0); - TFE_OpSetAttrBool(op, "transpose_b", 0); TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); return op; diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 41b5b8ff36e16100e349cb909dc79d90fa4866b0..5ba55a203ff70cc64c07e96b5a869a1f11c9334e 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -130,7 +130,7 @@ class GradientTape { const string& op_type, std::vector& output_tensors, gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, - BackwardFunction* backward_function, + const std::function& backward_function_getter, const std::function& backward_function_deleter); void DeleteTrace(int64 tensor_id); @@ -206,10 +206,9 @@ void GradientTape::RecordOperation( const string& op_type, std::vector& output_tensors, gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, - BackwardFunction* backward_function, + const std::function& backward_function_getter, const std::function& backward_function_deleter) { if (!ShouldRecord(input_tensor_id, input_dtypes)) { - backward_function_deleter(backward_function); return; } std::vector ids; @@ -229,7 +228,7 @@ void GradientTape::RecordOperation( tensors.push_back(o); } op_tape_[op_id] = OpTapeEntry{ - op_type, std::move(tensors), ids, backward_function, + op_type, std::move(tensors), std::move(ids), backward_function_getter(), backward_function_deleter}; } diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index b587e63227708427e7fae47f8f4a7b524d963ed9..c18b07603ae3841d3581741ab5a43f2e8b628356 100644 --- a/tensorflow/cc/BUILD +++ 
b/tensorflow/cc/BUILD @@ -411,6 +411,7 @@ tf_cc_test( srcs = ["gradients/nn_grad_test.cc"], deps = [ ":cc_ops", + ":cc_ops_internal", ":grad_op_registry", ":grad_testutil", ":gradient_checker", @@ -453,11 +454,33 @@ tf_cc_test( ], ) +# Generates separate libraries for array_ops and math_ops to reduce the dependency count of targets that depend on only these tf_gen_op_wrappers_cc( - name = "cc_ops", + name = "math_ops", + api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], + op_lib_names = [ + "math_ops", + ], + pkg = "//tensorflow/core", +) + +tf_gen_op_wrappers_cc( + name = "array_ops", api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], op_lib_names = [ "array_ops", + ], + pkg = "//tensorflow/core", +) + +tf_gen_op_wrappers_cc( + name = "cc_ops", + api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], + deps_internal = [ + ":array_ops_internal", + ":math_ops_internal", + ], + op_lib_names = [ "audio_ops", "candidate_sampling_ops", "control_flow_ops", @@ -468,7 +491,6 @@ tf_gen_op_wrappers_cc( "logging_ops", "lookup_ops", "manip_ops", - "math_ops", "nn_ops", "no_op", "parsing_ops", @@ -480,10 +502,21 @@ tf_gen_op_wrappers_cc( "user_ops", ], other_hdrs = [ + "ops/array_ops.h", "ops/const_op.h", + "ops/math_ops.h", "ops/standard_ops.h", ], + other_hdrs_internal = [ + "ops/array_ops_internal.h", + "ops/math_ops_internal.h", + ], pkg = "//tensorflow/core", + deps = [ + ":array_ops", + ":const_op", + ":math_ops", + ], ) tf_cc_test( diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index a32d1b1eb50fc715084f5ee663a732770db1883c..39593370d1c243e84dc5b6091724d1d404c102b0 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -853,11 +853,7 @@ void OpInfo::WriteClassDecl(WritableFile* h) const { } } - strings::StrAppend(&class_decl, "\n"); - - if (output_types.empty()) { - strings::StrAppend(&class_decl, " Operation operation;\n"); - } + strings::StrAppend(&class_decl, "\n 
Operation operation;\n"); for (int i = 0; i < output_types.size(); ++i) { strings::StrAppend(&class_decl, " ", output_types[i], " ", output_names[i], ";\n"); @@ -878,9 +874,11 @@ void OpInfo::GetOutput(string* out) const { string return_on_error = strings::StrCat("if (!", scope_str, ".ok()) return;"); + strings::StrAppend(out, " this->operation = Operation(ret);\n"); + // No outputs. if (graph_op_def.output_arg_size() == 0) { - strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n"); + strings::StrAppend(out, " return;\n"); return; } if (graph_op_def.output_arg_size() == 1) { diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 7f6ac4cae78d8d6e118837fce9ae5270336cdc89..6abc9e268e3ac97379954a34017ddffa010db67f 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -62,7 +62,7 @@ Scope::Impl::Impl(const std::shared_ptr& graph, refiner_(refiner), scope_used_(nullptr), colocation_constraints_(), - disable_shape_inference_(false) {} + disable_shape_inference_(refiner_ == nullptr) {} Scope Scope::NewRootScope() { Graph* graph = new Graph(OpRegistry::Global()); @@ -94,6 +94,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -110,6 +111,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -132,6 +134,7 @@ Scope::Impl::Impl(const 
Scope& other, Tags::ControlDeps, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -163,6 +166,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -178,6 +182,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError) exit_on_error_(true), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -194,6 +199,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(kernel_label), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -210,12 +216,30 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_( clear_colocations ? 
std::unordered_set() : other.impl()->GetColocationConstraints(colocate_with_op)), disable_shape_inference_(other.impl()->disable_shape_inference_) {} +Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice, + const string& assigned_device) + : graph_(other.impl()->graph_), + status_(other.impl()->status_), + name_map_(other.impl()->name_map_), + refiner_(other.impl()->refiner_), + scope_used_(other.impl()->scope_used_), + control_deps_(other.impl()->control_deps_), + name_(other.impl()->name_), + op_name_(other.impl()->op_name_), + exit_on_error_(other.impl()->exit_on_error_), + kernel_label_(other.impl()->kernel_label_), + device_(other.impl()->device_), + assigned_device_(assigned_device), + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} + std::unordered_set Scope::Impl::GetColocationConstraints( const Operation& colocate_with_op) const { std::unordered_set current_constraints(colocation_constraints_); @@ -299,6 +323,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const { if (!impl()->device_.empty()) { builder->Device(impl()->device_); } + if (!impl()->assigned_device_.empty()) { + builder->AssignedDevice(impl()->assigned_device_); + } } string Scope::Impl::GetUniqueName(const string& prefix, @@ -394,6 +421,10 @@ Scope Scope::WithDevice(const string& device) const { return Scope(new Impl(*this, Impl::Tags::Device(), device)); } +Scope Scope::WithAssignedDevice(const string& assigned_device) const { + return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device)); +} + Scope Scope::ColocateWith(const Operation& op) const { return Scope(new Impl(*this, Impl::Tags::Colocate(), op, /* clear_colocations */ false)); diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 30c32bd44b0f22d6b29dd3836d431807d0216818..e307d8989b6647dfac8d2691ed2171c86b7f3a7c 100644 --- a/tensorflow/cc/framework/scope.h +++ 
b/tensorflow/cc/framework/scope.h @@ -133,6 +133,10 @@ class Scope { /// the device field set to 'device'. Scope WithDevice(const string& device) const; + /// Returns a new scope. All ops created within the returned scope will have + /// their assigned device set to `assigned_device`. + Scope WithAssignedDevice(const string& assigned_device) const; + /// Return a new scope. All ops created within the returned scope will be /// co-located on the device where op is placed. /// NOTE: This function is intended to be use internal libraries only for diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 58adaef2e942a7fa6b0ce8d5534ac3e2fd380580..514e02e84146b6d95147d83182e5d9a07509cfa1 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -26,6 +26,8 @@ class ShapeRefiner; // graph, status, name_map, and refiner. // This is intended to enable the C API (which are used by other language // bindings) to create a Scope and access C++ functionality (i.e. gradients). +// +// Shape inference is disabled if `refiner` is nullptr. 
Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner); class Scope::Impl { @@ -58,6 +60,7 @@ class Scope::Impl { enum class ExitOnError; enum class KernelLabel; enum class Colocate; + enum class AssignedDevice; }; Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner, @@ -74,6 +77,7 @@ class Scope::Impl { Impl(const Scope& other, Tags::KernelLabel, const string& kernel_label); Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op, bool clear_colocations); + Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device); std::unordered_set GetColocationConstraints( const Operation& colocate_with_op) const; @@ -107,6 +111,7 @@ class Scope::Impl { const bool exit_on_error_ = false; const string kernel_label_ = ""; const string device_ = ""; + const string assigned_device_ = ""; const std::unordered_set colocation_constraints_; // If true, Scope::DoShapeInference() always returns Status:OK(). diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 588e96cb196189780037f66266484962ba0385e4..2a32a2ed6f7862a29f4ce3d1aba5fdbc86adc670 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -143,6 +143,33 @@ Status Relu6GradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper); +Status LeakyReluGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + float alpha; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); + internal::LeakyReluGrad::Attrs attrs; + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), + attrs.Alpha(alpha)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper); + +Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + float 
alpha; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); + internal::LeakyReluGrad::Attrs attrs; + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1), + attrs.Alpha(alpha)); + grad_outputs->push_back(dx); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("LeakyReluGrad", LeakyReluGradGradHelper); + Status EluGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index aa72cf7ba2a958f54d50b59f0edaefb27edf0e86..f5a09e09dcda3e06c71d44d5fa5a1b121a9ade58 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/framework/gradient_checker.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/gradients/grad_testutil.h" +#include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -160,6 +161,32 @@ TEST_F(NNGradTest, Relu6Grad) { RunTest(x, x_init_value, y, shape); } +TEST_F(NNGradTest, LeakyReluGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = ops::internal::LeakyRelu(scope_, x); + // Avoid input values where Leaky ReLU gradient is not well defined (around + // zero). + Tensor x_init_value = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + RunTest(x, x_init_value, y, shape); +} + +TEST_F(NNGradTest, LeakyReluGradGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + // Avoid input values where Leaky ReLU gradient is not well defined (around + // zero). 
+ Tensor x_init_value = test::AsTensor( + {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, {5, 2}); + Tensor features = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + auto y = ops::internal::LeakyReluGrad(scope_, x, features); + RunTest(x, x_init_value, y, shape); +} + TEST_F(NNGradTest, EluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 5bf4af1014abd84d668947ee8aeff09578d4bff4..311313b8f2318f6679678104bb55e0b5911fc2c5 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -258,6 +258,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -323,6 +324,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -357,6 +359,79 @@ tf_cc_test( ], ) +cc_library( + name = "shape_inference", + srcs = ["shape_inference.cc"], + hdrs = ["shape_inference.h"], + deps = [ + ":shape_inference_helpers", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + ":shape_inference", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "shape_inference_test", + srcs = ["shape_inference_test.cc"], + deps = [ + ":shape_inference", + 
":test_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:ops", + "//tensorflow/core:framework", + "//tensorflow/core:ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/kernels:constant_op", + ], +) + +cc_library( + name = "encapsulate_util", + srcs = ["encapsulate_util.cc"], + hdrs = ["encapsulate_util.h"], + deps = [ + ":shape_inference", + "//tensorflow/core:graph", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "encapsulate_util_test", + srcs = ["encapsulate_util_test.cc"], + deps = [ + ":encapsulate_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core:framework", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "compilation_passes", srcs = [ @@ -383,12 +458,17 @@ cc_library( ":shape_inference_helpers", ":union_find", ":xla_cluster_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope_internal", "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/compiler/jit/legacy_flags:build_xla_ops_pass_flags", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", @@ -400,6 +480,8 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -471,6 +553,7 @@ 
tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -489,6 +572,7 @@ tf_cc_test( ":compilation_passes", ":node_matchers", ":xla_cluster_util", + ":xla_cpu_device", ":xla_gpu_device", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", @@ -500,6 +584,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla:test_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", + "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -509,6 +594,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/grappler/optimizers/data:graph_utils", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -610,6 +696,7 @@ cc_library( deps = [ "//tensorflow/cc:ops", "//tensorflow/compiler/xla:test", + "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/algorithm:container", @@ -625,6 +712,7 @@ tf_cc_test( deps = [ ":node_matchers", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:ops", "//tensorflow/core:ops", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 9e3fd93cda1bfadfe968ffc0433cfe50ca2d7670..054f31ba3352b2215e6b0448c8ec8a70cb98b8e5 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -14,8 +14,18 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/build_xla_ops_pass.h" +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope_internal.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/control_flow_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" @@ -31,152 +41,298 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { +namespace { +void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) { + std::vector out_edges(old_node->out_edges().begin(), + old_node->out_edges().end()); + for (const Edge* edge : out_edges) { + // TODO(sanjoy): This does not update NodeDef inputs. To be able to update + // NodeDef inputs we first need to fix encapsulate_subgraphs_pass to fix up + // the NodeDef inputs to the function call nodes. 
+ g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input()); + g->RemoveEdge(edge); + } +} -static Status BuildXlaCompileNode( - const string& nodename, const string& function_name, - const AttrValueMap& function_attr, const string& device_name, - const DataTypeVector& constant_dtypes, int num_resources, - const DataTypeVector& arg_dtypes, Graph* graph, Node** node) { - NodeDef def; - def.set_name(graph->NewName(nodename)); - def.set_op("_XlaCompile"); - def.set_device(device_name); - AddNodeAttr("Tconstants", constant_dtypes, &def); - AddNodeAttr("Targs", arg_dtypes, &def); - AddNodeAttr("Nresources", num_resources, &def); - NameAttrList function; - function.set_name(function_name); - *function.mutable_attr() = function_attr; - AddNodeAttr("function", function, &def); +// Returns a data value that is dead iff `control` is dead. +Output ControlToData(const Scope& scope, Node* control) { + Output data = ops::Const(scope.WithOpName("ctrl_as_data"), + Tensor(DT_BOOL, TensorShape({0}))); + scope.graph()->AddControlEdge(control, data.node()); + return Output(data.node()); +} - Status status; - *node = graph->AddNode(def, &status); - return status; +// Returns an operation that can be control-depended on that is dead iff `data` +// is dead. +Operation DataToControl(const Scope& scope, Output data) { + return Operation( + ops::Identity(scope.WithOpName("data_as_ctrl"), data).node()); } -static Status BuildXlaRunNode(const string& nodename, const string& device_name, - const DataTypeVector& arg_dtypes, - const DataTypeVector& result_dtypes, Graph* graph, - Node** node) { - NodeDef def; - def.set_name(graph->NewName(nodename)); - def.set_op("_XlaRun"); - def.set_device(device_name); - AddNodeAttr("Targs", arg_dtypes, &def); - AddNodeAttr("Tresults", result_dtypes, &def); +// Replaces each outgoing edge from `old_node` with a merge node that merges in +// the corresponding output from `new_node`. 
+void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node) { + if (!s.status().ok()) { + return; + } - Status status; - *node = graph->AddNode(def, &status); - return status; + std::vector merged_outputs(old_node->num_outputs(), Output(nullptr)); + + std::vector data_edges; + absl::c_copy_if(old_node->out_edges(), std::back_inserter(data_edges), + [](const Edge* e) { return !e->IsControlEdge(); }); + + for (const Edge* e : data_edges) { + int oidx = e->src_output(); + Output merged_output = merged_outputs[oidx]; + if (merged_output.node() == nullptr) { + ops::Merge merge_op(s.WithOpName(absl::StrCat("merge_oidx_", oidx)), + {Output(old_node, oidx), Output(new_node, oidx)}); + merged_output = merged_outputs[oidx] = merge_op.output; + } + + Node* dst = e->dst(); + int dst_idx = e->dst_input(); + + s.graph()->RemoveEdge(e); + s.graph()->AddEdge(merged_output.node(), merged_output.index(), dst, + dst_idx); + } } -static Status GetXlaAttrs(Node* node, int* num_constant_args, - int* num_resource_args, DataTypeVector* const_dtypes, - DataTypeVector* arg_dtypes) { +// Replaces each control successor of `old_node` to execute whenever either +// `old_node` or `new_node` is executed. +void MergeOutgoingControlEdges(const Scope& s, Node* old_node, Node* new_node) { + if (!s.status().ok()) { + return; + } + + std::vector ctrl_edges; + absl::c_copy_if(old_node->out_edges(), std::back_inserter(ctrl_edges), + [](const Edge* e) { return e->IsControlEdge(); }); + + if (ctrl_edges.empty()) { + return; + } + + // We can't merge control edges directly so we instead first "convert" them to + // normal values that can be merged, merge the values and then "convert" the + // merged value back into control. + // + // NB! We need to copy out the outgoing control edges before constructing + // old_ctrl_as_data otherwise the control edge from old_node to the constant + // in ControlToData will be present in ctrl_edges. 
+ + Output old_ctrl_as_data = ControlToData(s, old_node); + Output new_ctrl_as_data = ControlToData(s, new_node); + + ops::Merge ctrl_merge_as_data(s.WithOpName("ctrl_merge"), + {old_ctrl_as_data, new_ctrl_as_data}); + Operation ctrl_merge = DataToControl(s, ctrl_merge_as_data.output); + + for (const Edge* e : ctrl_edges) { + s.graph()->AddControlEdge(ctrl_merge.node(), e->dst()); + s.graph()->RemoveControlEdge(e); + } +} + +struct XlaClusterInfo { + std::vector constant_inputs; + std::vector non_constant_inputs; + std::vector resource_inputs; + NameAttrList function; +}; + +Output IncomingEdgeAsOutput(const Edge* e) { + return Output(e->src(), e->src_output()); +} + +Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) { + int num_constant_inputs, num_resource_inputs; TF_RETURN_IF_ERROR( - GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, num_constant_args)); + GetNodeAttr(n->attrs(), kXlaNumConstantArgsAttr, &num_constant_inputs)); TF_RETURN_IF_ERROR( - GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, num_resource_args)); + GetNodeAttr(n->attrs(), kXlaNumResourceArgsAttr, &num_resource_inputs)); - if (*num_constant_args < 0 || *num_resource_args < 0 || - *num_constant_args + *num_resource_args > node->num_inputs()) { + if (num_constant_inputs < 0 || num_resource_inputs < 0 || + num_constant_inputs + num_resource_inputs > n->num_inputs()) { return errors::InvalidArgument( "Invalid number of constant/resource arguments to XLA kernel."); } - const int num_nonconst_args = - node->num_inputs() - *num_constant_args - *num_resource_args; + int num_non_constant_inputs = + n->num_inputs() - num_constant_inputs - num_resource_inputs; + + std::vector input_edges_vector; + TF_RETURN_IF_ERROR(n->input_edges(&input_edges_vector)); + absl::Span input_edges(input_edges_vector); + + absl::c_transform(input_edges.subspan(0, num_constant_inputs), + std::back_inserter(result->constant_inputs), + IncomingEdgeAsOutput); - const DataTypeVector& input_types = 
node->input_types(); - std::copy(input_types.begin(), input_types.begin() + *num_constant_args, - std::back_inserter(*const_dtypes)); - std::copy(input_types.begin() + *num_constant_args, - input_types.begin() + *num_constant_args + num_nonconst_args, - std::back_inserter(*arg_dtypes)); + absl::c_transform( + input_edges.subspan(num_constant_inputs, num_non_constant_inputs), + std::back_inserter(result->non_constant_inputs), IncomingEdgeAsOutput); + + absl::c_transform( + input_edges.subspan(num_constant_inputs + num_non_constant_inputs, + num_resource_inputs), + std::back_inserter(result->resource_inputs), IncomingEdgeAsOutput); + + result->function.set_name(n->type_string()); + *result->function.mutable_attr() = n->def().attr(); return Status::OK(); } -static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node, - int prefix_to_ignore) { - for (const Edge* edge : old_node->in_edges()) { - if (edge->IsControlEdge()) { - g->AddControlEdge(edge->src(), new_node); - } else if (edge->dst_input() >= prefix_to_ignore) { - g->AddEdge(edge->src(), edge->src_output(), new_node, - edge->dst_input() - prefix_to_ignore); +Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) { + for (const Edge* e : from->in_edges()) { + if (e->IsControlEdge()) { + g->AddControlEdge(e->src(), to); } } + + return Status::OK(); } -static void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) { - std::vector out_edges(old_node->out_edges().begin(), - old_node->out_edges().end()); - for (const Edge* edge : out_edges) { - // TODO(sanjoy): This does not update NodeDef inputs. 
- g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input()); - g->RemoveEdge(edge); +void RemoveAllIncomingControlEdges(Graph* g, Node* n) { + std::vector incoming_ctrl_edges; + absl::c_copy_if(n->in_edges(), std::back_inserter(incoming_ctrl_edges), + [](const Edge* e) { return e->IsControlEdge(); }); + for (const Edge* e : incoming_ctrl_edges) { + g->RemoveControlEdge(e); + } +} + +// Returns true (into `result`) if `node` must be compiled. +Status NodeRequiresCompilation(Node* n, bool* result) { + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n->assigned_device_name(), &device_type)); + const XlaOpRegistry::DeviceRegistration* registration = nullptr; + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) { + return errors::Internal("Could not find compilation device ", + device_type.type()); } + *result = registration->requires_compilation; + return Status::OK(); } -static Status ReplaceNodeWithXlaCompileAndRun(Graph* g, Node* n) { - int num_constant_args, num_resource_args; - DataTypeVector const_dtypes; - DataTypeVector arg_dtypes; +Status ReplaceNodeWithXlaCompileAndXlaRun( + const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled, + Graph* g, Node* n) { + bool requires_compilation; + TF_RETURN_IF_ERROR(NodeRequiresCompilation(n, &requires_compilation)); + if (!lazy_compilation_enabled) { + requires_compilation = true; + } + + Status status; + Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr) + .NewSubScope(n->name()) + .WithDevice(n->requested_device()) + .WithAssignedDevice(n->assigned_device_name()); - TF_RETURN_IF_ERROR(GetXlaAttrs(n, &num_constant_args, &num_resource_args, - &const_dtypes, &arg_dtypes)); + XlaClusterInfo cluster_info; + TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info)); - Node *compile_node, *run_node; + ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"), + /*constants=*/cluster_info.constant_inputs, + 
/*args=*/cluster_info.non_constant_inputs, + /*resources=*/cluster_info.resource_inputs, + /*must_compile=*/requires_compilation, + cluster_info.function); + TF_RETURN_IF_ERROR( + CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node())); - TF_RETURN_IF_ERROR(BuildXlaCompileNode( - n->name(), n->type_string(), n->def().attr(), n->requested_device(), - const_dtypes, num_resource_args, arg_dtypes, g, &compile_node)); + if (requires_compilation) { + // "Strict" compilation: every _XlaCompile invocation must compile the + // cluster. + std::vector xla_run_args = cluster_info.non_constant_inputs; + absl::c_copy(cluster_info.resource_inputs, + std::back_inserter(xla_run_args)); + ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args, + xla_compile.key, n->output_types()); - DataTypeVector arg_dtypes_with_resources = arg_dtypes; - for (int i = 0; i < num_resource_args; i++) { - arg_dtypes_with_resources.push_back(DT_RESOURCE); - } + MoveOutgoingEdges(g, /*old_node=*/n, + /*new_node=*/xla_run.operation.node()); + g->RemoveNode(n); + } else { + // "Lazy" compilation: an _XlaCompile invocation may decide not to compile + // the cluster based on profitability heuristics. - TF_RETURN_IF_ERROR(BuildXlaRunNode(n->name(), n->requested_device(), - arg_dtypes_with_resources, - n->output_types(), g, &run_node)); + // We generate the following graph: + // + // (use_tf_call, use_xla_run) = + // Switch(pred=xla_compile.compilation_successful, + // value=xla_compile.key) + // + // tf_call_outputs = cluster_N(..., ^use_tf_call) + // xla_run_outputs = _XlaRun(..., key=use_xla_run) + // outputs = Merge(tf_call_outputs, xla_run_outputs). 
+ ops::Switch s(root.WithOpName("predicated_compilation_key"), + xla_compile.key, xla_compile.compilation_successful); + Output predicated_compilation_key = s.output_true; + Output inverse_predicated_compilation_key = s.output_false; - compile_node->set_assigned_device_name(n->assigned_device_name()); - run_node->set_assigned_device_name(n->assigned_device_name()); + std::vector xla_run_args = cluster_info.non_constant_inputs; + absl::c_copy(cluster_info.resource_inputs, + std::back_inserter(xla_run_args)); + ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args, + predicated_compilation_key, n->output_types()); - CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node, - /*prefix_to_ignore=*/0); - CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/run_node, - /*prefix_to_ignore=*/num_constant_args); + MergeOutgoingControlEdges(root, /*old_node=*/n, + /*new_node=*/xla_run.operation.node()); - // The compilation_key output. - g->AddEdge(compile_node, 0, run_node, n->num_inputs() - num_constant_args); + MergeOutgoingDataEdges(root, /*old_node=*/n, + /*new_node=*/xla_run.operation.node()); - MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/run_node); - g->RemoveNode(n); + TF_RETURN_IF_ERROR(root.status()); + + // We already have a TensorFlow function call into the cluster -- the + // original node we set out to rewrite. We just wire in the correct control + // deps and we're done. + RemoveAllIncomingControlEdges(g, n); + g->AddControlEdge( + DataToControl(root, inverse_predicated_compilation_key).node(), n); + n->ClearAttr(kXlaCompiledKernelAttr); + } return Status::OK(); } +} // namespace Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { Graph* graph = options.graph->get(); - for (Node* n : graph->op_nodes()) { - // In all cases, only try to compile computational nodes. 
- if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) { - continue; - } + // Copy out the nodes we want to rewrite to avoid modifying the graph while we + // iterate on graph->op_nodes(). + std::vector xla_compiled_kernels; + absl::c_copy_if(graph->op_nodes(), std::back_inserter(xla_compiled_kernels), + [](const Node* n) { + if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) { + return false; + } - // Only compile nodes that are marked for compilation by the - // compilation-marking pass (via 'attr_name'). - if (IsXlaCompiledKernel(*n)) { - TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndRun(graph, n)); - } + // Only compile nodes that are marked for compilation by the + // compilation-marking pass (via 'attr_name'). + return IsXlaCompiledKernel(*n); + }); + + bool lazy_compilation_enabled = enable_lazy_compilation_ + ? *enable_lazy_compilation_ + : legacy_flags::GetBuildXlaOpsPassFlags() + .tf_xla_enable_lazy_compilation; + + for (Node* n : xla_compiled_kernels) { + TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun( + *options.flib_def, lazy_compilation_enabled, graph, n)); } if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def); } + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h index 1dd38fa95186dfbe458166caa23a131fbe3c9510..58f7c4b3a0d1472f602e8234f9f08c23dfe78a34 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.h +++ b/tensorflow/compiler/jit/build_xla_ops_pass.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ #define TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ +#include "absl/types/optional.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/lib/core/status.h" @@ -25,7 +26,17 @@ namespace tensorflow { // executes (using XLA) TF function calls marked with "_XlaCompiledKernel". 
class BuildXlaOpsPass : public GraphOptimizationPass { public: + // If enable_lazy_compilation is not nullopt then *enable_lazy_compilation + // overrides --tf_xla_enable_lazy_compilation flag in deciding whether lazy + // compilation is enabled. + explicit BuildXlaOpsPass( + absl::optional enable_lazy_compilation = absl::nullopt) + : enable_lazy_compilation_(enable_lazy_compilation) {} + Status Run(const GraphOptimizationPassOptions& options) override; + + private: + absl::optional enable_lazy_compilation_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index b7cb4506b9a372f26839d1fbce4754ff720ffee4..11df946cc186660242574c2644463a26ead44f1f 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -22,18 +22,44 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/node_matchers.h" +#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" namespace tensorflow { namespace { +class BuildXlaOpsTest : public ::testing::Test { + protected: + void SetUp() override { + // This is needed to register the XLA_* devices. 
+ CHECK(DeviceFactory::AddDevices( + SessionOptions(), "/job:localhost/replica:0/task:0", &devices_) + .ok()); + } + + void TearDown() override { + for (Device* device : devices_) { + delete device; + } + } + + private: + std::vector devices_; +}; + using ::tensorflow::testing::FindNodeByName; +using ::tensorflow::testing::matchers::Attr; using ::tensorflow::testing::matchers::CtrlDeps; +using ::tensorflow::testing::matchers::Inputs; using ::tensorflow::testing::matchers::NodeWith; using ::tensorflow::testing::matchers::Op; +using ::tensorflow::testing::matchers::Out; +using ::testing::_; Status BuildXlaOps(const Scope& s, std::unique_ptr* result) { auto graph = absl::make_unique(OpRegistry::Global()); @@ -42,42 +68,56 @@ Status BuildXlaOps(const Scope& s, std::unique_ptr* result) { // Assign all nodes to the CPU device. static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; for (Node* n : graph->nodes()) { - if (n->assigned_device_name().empty()) { + if (n->requested_device().empty()) { n->set_assigned_device_name(kCpuDevice); + } else { + n->set_assigned_device_name(n->requested_device()); } } GraphOptimizationPassOptions opt_options; opt_options.graph = &graph; - BuildXlaOpsPass pass; + BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true); TF_RETURN_IF_ERROR(pass.Run(opt_options)); + VLOG(3) << graph->ToGraphDefDebug().DebugString(); *result = std::move(graph); return Status::OK(); } Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, - const string& node_name, Node** result) { + const string& node_name, int num_constant_args, + int num_resource_args, Node** result) { NodeDef call_node; call_node.set_name(node_name); call_node.set_op(callee_name); AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node); - AddNodeAttr(kXlaNumConstantArgsAttr, 0, &call_node); - AddNodeAttr(kXlaNumResourceArgsAttr, 0, &call_node); + AddNodeAttr(kXlaNumConstantArgsAttr, num_constant_args, &call_node); + AddNodeAttr(kXlaNumResourceArgsAttr, 
num_resource_args, &call_node); Status s; *result = graph->AddNode(call_node, &s); return s; } -Node* MakeWrite(const Scope& scope, const string& id) { - Output var_handle = - ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); - Output value_to_write = - ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f); - ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle, - value_to_write); +Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, + const string& node_name, Node** result) { + return MakeXlaCompiledKernel(graph, callee_name, node_name, + /*num_constant_args=*/0, /*num_resource_args=*/0, + result); +} + +Node* MakeWrite(const Scope& scope, Output value_to_write, const string& id) { + Output var_handle = ops::VarHandleOp(scope.WithOpName("Var_" + id), DT_FLOAT, + TensorShape({})); + ops::AssignVariableOp assign_op(scope.WithOpName("Assignee_" + id), + var_handle, value_to_write); return assign_op.operation.node(); } +Node* MakeWrite(const Scope& scope, const string& id) { + return MakeWrite( + scope, ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f), id); +} + FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { FunctionDefLibrary flib_def; FunctionDef func = FunctionDefHelper::Create( @@ -89,14 +129,16 @@ FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { return flib_def; } -TEST(BuildXlaOps, ControlDepsPreserved) { - Scope root = Scope::NewRootScope().ExitOnError(); +TEST_F(BuildXlaOpsTest, ControlDepsPreserved) { + const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + Scope root = Scope::NewRootScope().WithDevice(kXlaDeviceName).ExitOnError(); FunctionDefLibrary flib_def = CreateFunctionDefLibWithConstFunction("cluster_0"); TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); Node* call; TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + 
call->set_requested_device(kXlaDeviceName); Node* write_op = MakeWrite(root, "write"); root.graph()->AddControlEdge(call, write_op); @@ -108,5 +150,85 @@ TEST(BuildXlaOps, ControlDepsPreserved) { EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun"))))); } +TEST_F(BuildXlaOpsTest, CleanFailureOnBogusAttr) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* call; + TF_ASSERT_OK( + MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", 100, 100, &call)); + + Node* write_op = MakeWrite(root, "write"); + root.graph()->AddControlEdge(call, write_op); + + std::unique_ptr graph; + Status failure_status = BuildXlaOps(root, &graph); + ASSERT_FALSE(failure_status.ok()); + EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT); +} + +TEST_F(BuildXlaOpsTest, OnNonXlaDevice) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* call; + TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + TF_ASSERT_OK(root.DoShapeInference(call)); + + Node* write_op = MakeWrite(root, Output(call), "write_result"); + + auto xla_compile = NodeWith(Op("_XlaCompile"), Attr("must_compile", false)); + auto predicated_compilation_key = + NodeWith(Op("Switch"), Inputs(Out(0, xla_compile), Out(1, xla_compile))); + auto xla_run = + NodeWith(Op("_XlaRun"), Inputs(Out(1, predicated_compilation_key))); + auto tf_call = + NodeWith(Op("cluster_0"), + CtrlDeps(NodeWith(Op("Identity"), + Inputs(Out(0, predicated_compilation_key))))); + auto merge = NodeWith(Op("Merge"), Inputs(Out(tf_call), Out(xla_run))); + auto assign_var = NodeWith(Op("AssignVariableOp"), Inputs(_, Out(merge))); + + std::unique_ptr graph; + TF_ASSERT_OK(BuildXlaOps(root, &graph)); + + Node* 
write_op_new = FindNodeByName(graph.get(), write_op->name()); + ASSERT_NE(write_op_new, nullptr); + EXPECT_THAT(write_op_new, assign_var); +} + +TEST_F(BuildXlaOpsTest, OnXlaDevice) { + const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + Scope root = Scope::NewRootScope().WithDevice(kXlaDeviceName).ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* call; + TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + call->set_requested_device(kXlaDeviceName); + TF_ASSERT_OK(root.DoShapeInference(call)); + + Node* write_op = MakeWrite(root, Output(call), "write_result"); + + std::unique_ptr graph; + TF_ASSERT_OK(BuildXlaOps(root, &graph)); + + auto xla_op = + NodeWith(Op("_XlaRun"), Inputs(Out(NodeWith(Op("_XlaCompile"))))); + auto assign_var = + NodeWith(Op("AssignVariableOp"), Inputs(Out(NodeWith()), Out(xla_op))); + + Node* write_op_new = FindNodeByName(graph.get(), write_op->name()); + ASSERT_NE(write_op_new, nullptr); + EXPECT_THAT(write_op_new, assign_var); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 9128b48da3fe9dd3d85d146e16c153c1b3bebf4c..b7ae7fbeb3912882368dc828e8d6fcd50735b04e 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -14,11 +14,14 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/deadness_analysis.h" +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" // ALGORITHM OVERVIEW @@ -296,7 +299,7 @@ class SymbolPredicate : public Predicate { template /*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) { - gtl::FlatSet visited; + absl::flat_hash_set visited; std::vector stack; stack.push_back(p); @@ -383,6 +386,8 @@ class PredicateFactory { } Predicate* MakeAndOrImpl(absl::Span operands, bool is_and); + Predicate* MakeInternedAndOr(std::vector simplified_ops, + Predicate::Kind pred_kind); // Predicate instances are interned, meaning that there is only a single // instance of a Predicate object with a given content. 
This makes checking @@ -417,24 +422,53 @@ class PredicateFactory { } }; - gtl::FlatMap, - HashSignatureForAndOr> + absl::flat_hash_map, + HashSignatureForAndOr> interned_and_or_instances_; - gtl::FlatMap> + absl::flat_hash_map> interned_not_instances_; - gtl::FlatMap> + absl::flat_hash_map> interned_and_rec_instances_; - gtl::FlatMap, - HashSignatureForSymbol> + absl::flat_hash_map, + HashSignatureForSymbol> interned_symbol_instances_; }; +Predicate* PredicateFactory::MakeInternedAndOr( + std::vector simplified_ops, Predicate::Kind pred_kind) { + std::stable_sort( + simplified_ops.begin(), simplified_ops.end(), + [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + + auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); + if (it != interned_and_or_instances_.end()) { + return it->second.get(); + } + + simplified_ops.shrink_to_fit(); + // NB! Because we'll use a non-owning reference to simplified_ops in the + // key for interned_and_or_instances_ we need to be careful to std::move() + // it all the way through. + absl::Span operands_slice = simplified_ops; + std::unique_ptr new_pred = + pred_kind == Predicate::Kind::kAnd + ? Make(std::move(simplified_ops)) + : Make(std::move(simplified_ops)); + + Predicate* new_pred_ptr = new_pred.get(); + interned_and_or_instances_.emplace( + SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred)); + return new_pred_ptr; +} + // Common code to create AndPredicate or OrPredicate instances. Predicate* PredicateFactory::MakeAndOrImpl( absl::Span operands, bool is_and) { Predicate::Kind pred_kind = is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; - gtl::FlatSet simplified_ops_set; + Predicate::Kind other_pred_kind = + is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd; + absl::flat_hash_set simplified_ops_set; std::vector simplified_ops; for (Predicate* op : operands) { // Simplify A&A => A and A|A => A. 
@@ -459,7 +493,7 @@ Predicate* PredicateFactory::MakeAndOrImpl( } // Simplify "A&~A=>False" and "A|~A=>True". - gtl::FlatSet negated_ops; + absl::flat_hash_set negated_ops; for (Predicate* op : simplified_ops) { if (op->kind() == Predicate::Kind::kNot) { negated_ops.insert(dynamic_cast(*op).operand()); @@ -472,30 +506,63 @@ Predicate* PredicateFactory::MakeAndOrImpl( } } - std::stable_sort( - simplified_ops.begin(), simplified_ops.end(), - [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + // If all ops contain the same subop, then factor it out thanks to the + // distributive property. Such as: + // - (A & B) | (A & C) | (A & D) => A & (B | C | D) + // - (A | B) & (A | C) & (A | D) => A | (B & C & D) + // + // First find any predicates contained in all subops. + std::vector common_inner_operands; + absl::flat_hash_set common_inner_operands_set; + for (Predicate* op : simplified_ops) { + if (op->kind() != other_pred_kind) { + common_inner_operands.clear(); + break; + } - auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); - if (it == interned_and_or_instances_.end()) { - simplified_ops.shrink_to_fit(); - // NB! Because we'll use a non-owning reference to simplified_ops in the - // key for interned_and_or_instances_ we need to be careful to std::move() - // it all the way through. - absl::Span operands_slice = simplified_ops; - std::unique_ptr new_pred = - is_and ? 
Make(std::move(simplified_ops)) - : Make(std::move(simplified_ops)); + if (common_inner_operands.empty()) { + common_inner_operands.insert(common_inner_operands.end(), + op->GetOperands().begin(), + op->GetOperands().end()); + } else { + std::vector sub_ops_intersection; + common_inner_operands.clear(); + absl::c_copy_if(op->GetOperands(), + std::back_inserter(common_inner_operands), + [&](Predicate* sub_op) { + return common_inner_operands_set.count(sub_op) == 1; + }); + } + if (common_inner_operands.empty()) break; + common_inner_operands_set.clear(); + common_inner_operands_set.insert(common_inner_operands.begin(), + common_inner_operands.end()); + } - Predicate* new_pred_ptr = new_pred.get(); - CHECK(interned_and_or_instances_ - .emplace(SignatureForAndOr(pred_kind, operands_slice), - std::move(new_pred)) - .second); - return new_pred_ptr; - } else { - return it->second.get(); + if (common_inner_operands.empty()) { + return MakeInternedAndOr(std::move(simplified_ops), pred_kind); } + + // For all predicates that can be factored out, remove them and recreate the + // subops. 
+ std::vector factored_ops; + for (Predicate* op : simplified_ops) { + std::vector new_sub_op_ops; + absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops), + [&](Predicate* sub_op) { + return std::find(common_inner_operands.begin(), + common_inner_operands.end(), + sub_op) == common_inner_operands.end(); + }); + factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and)); + } + + Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and); + std::vector outer_ops; + outer_ops.push_back(new_inner_op); + outer_ops.insert(outer_ops.end(), common_inner_operands.begin(), + common_inner_operands.end()); + return MakeAndOrImpl(outer_ops, !is_and); } class DeadnessAnalysisImpl : public DeadnessAnalysis { @@ -507,12 +574,14 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { Status PopulateWithReversePostOrder(absl::Span rpo); bool HasInputsWithMismatchingDeadness(const Node& node) override; void Print() const override; - gtl::FlatMap PredicateMapAsString() const; + absl::flat_hash_map PredicateMapAsString() + const; private: enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; - std::vector GetIncomingPreds(Node* n, EdgeKind edge_kind); + Status GetInputPreds(Node* n, EdgeKind edge_kind, + std::vector* result); // Sets the predicate for output `output_idx` of `n` to `pred`. 
Sets the i'th // bit of `should_revisit` if `pred` is different from the current predicate @@ -549,7 +618,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { Status HandleNode(Node* n, std::vector* should_revisit); const Graph& graph_; - gtl::FlatMap predicate_map_; + absl::flat_hash_map predicate_map_; PredicateFactory predicate_factory_; bool vlog_; }; @@ -558,9 +627,10 @@ TensorId InputEdgeToTensorId(const Edge* e) { return TensorId(e->src()->name(), e->src_output()); } -std::vector DeadnessAnalysisImpl::GetIncomingPreds( - Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) { - std::vector incoming_preds; +Status DeadnessAnalysisImpl::GetInputPreds( + Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind, + std::vector* result) { + result->clear(); for (const Edge* in_edge : n->in_edges()) { bool should_process = edge_kind == EdgeKind::kDataAndControl || @@ -569,17 +639,27 @@ std::vector DeadnessAnalysisImpl::GetIncomingPreds( if (should_process) { auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); - CHECK(it != predicate_map_.end()) << n->name(); - incoming_preds.push_back(it->second); + if (it == predicate_map_.end()) { + GraphCycles graph_cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles)); + + // If we didn't return with an error above then the graph is probably + // fine and we have a bug in deadness analysis. + return errors::Internal("Could not find input ", in_edge->DebugString(), + " to ", n->name(), + " when visiting the graph in post-order. 
Most " + "likely indicates a bug in deadness analysis."); + } + result->push_back(it->second); } } - return incoming_preds; + return Status::OK(); } Status DeadnessAnalysisImpl::HandleSwitch(Node* n, std::vector* should_revisit) { - std::vector input_preds = - GetIncomingPreds(n, EdgeKind::kDataAndControl); + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); const Edge* pred_edge; TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge)); Predicate* true_switch = predicate_factory_.MakeSymbolPredicate( @@ -608,17 +688,31 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n, } namespace { -const Edge* FindUniqueBackedge(Node* merge) { +Status CreateMultipleNextIterationInputsError(Node* merge) { + std::vector backedges; + for (const Edge* backedge : merge->in_edges()) { + if (backedge->src()->IsNextIteration()) { + backedges.push_back(absl::StrCat(" ", SummarizeNode(*backedge->src()))); + } + } + return errors::InvalidArgument( + "Multiple NextIteration inputs to merge node ", SummarizeNode(*merge), + ": \n", absl::StrJoin(backedges, "\n"), + "\nMerge nodes can have at most one incoming NextIteration edge."); +} + +Status FindUniqueBackedge(Node* merge, const Edge** result) { + *result = nullptr; CHECK(merge->IsMerge()); - const Edge* result = nullptr; for (const Edge* e : merge->in_edges()) { if (e->src()->IsNextIteration()) { - CHECK_EQ(result, nullptr) - << "Multiple backedges to " << merge->DebugString(); - result = e; + if (*result != nullptr) { + return CreateMultipleNextIterationInputsError(merge); + } + *result = e; } } - return result; + return Status::OK(); } // If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step @@ -697,9 +791,12 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, return Status::OK(); } + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds)); + // We're visiting this merge for the first time and it is a acyclic merge. 
- Predicate* input_data_pred = predicate_factory_.MakeOrPredicate( - GetIncomingPreds(n, EdgeKind::kDataOnly)); + Predicate* input_data_pred = + predicate_factory_.MakeOrPredicate(input_preds); SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); return Status::OK(); @@ -710,7 +807,9 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, // of an unvisited backedge. Try to pattern match the predicate expression // for that backedge (which should be visited now) into an and recurrence // for the merge node. - if (const Edge* unique_backedge = FindUniqueBackedge(n)) { + const Edge* unique_backedge; + TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &unique_backedge)); + if (unique_backedge) { if (Predicate* step = DeduceStepPredicate( &predicate_factory_, it->second, predicate_map_[InputEdgeToTensorId(unique_backedge)])) { @@ -741,8 +840,8 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, std::vector* should_revisit) { // In addition to being alive or dead based on the inputs, a _Recv can also // acquire a dead signal from a _Send. - std::vector input_preds = - GetIncomingPreds(n, EdgeKind::kDataAndControl); + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); input_preds.push_back(predicate_factory_.MakeSymbolPredicate( TensorId(n->name(), 0), /*must_be_true=*/false)); SetPredicate(n, {0, Graph::kControlSlot}, @@ -754,8 +853,9 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, Status DeadnessAnalysisImpl::HandleGeneric(Node* n, std::vector* should_revisit) { // Generally nodes are alive iff all their inputs are alive. 
- Predicate* pred = predicate_factory_.MakeAndPredicate( - GetIncomingPreds(n, EdgeKind::kDataAndControl)); + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); + Predicate* pred = predicate_factory_.MakeAndPredicate(input_preds); for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) { SetPredicate(n, output_idx, pred, should_revisit); } @@ -912,9 +1012,9 @@ DeadnessAnalysis::~DeadnessAnalysis() {} return Status::OK(); } -gtl::FlatMap +absl::flat_hash_map DeadnessAnalysisImpl::PredicateMapAsString() const { - gtl::FlatMap result; + absl::flat_hash_map result; std::vector tensor_ids; for (const auto& kv_pair : predicate_map_) { CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second); diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h index 3df2679c629ce801fc6c9006415dcd27b40c078e..354782374ad070a3d19ddd68bfb986d5a8285e51 100644 --- a/tensorflow/compiler/jit/deadness_analysis_internal.h +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -16,15 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ #define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { namespace deadness_analysis_internal { // Returns a map describing the predicate each Tensor was mapped to. For // testing purposes only. -using PredicateMapTy = gtl::FlatMap; +using PredicateMapTy = absl::flat_hash_map; Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); // Returns a map describing the predicate each Tensor was mapped to. 
For diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 28a56044d5e3795fc3ecf5d1092491b87cb90f01..617e31488c7daeb714c0ff7056b786e4eaf7873f 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -384,10 +384,31 @@ TEST(DeadnessAnalysisTest, OrOfAnd) { EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node())); } -TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) { - // This demonstrates one of the weaknesses in the current approach -- since we - // only do some basic simplifications we can't see that "(A|B)&C" == - // "(A&C)|(B&C)". +TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) { + // (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) == #true + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "A"); + ops::Switch sw_1 = CreateSwitch(root, "B"); + Output add0 = + ops::Add(root.WithOpName("and0"), sw_0.output_false, sw_1.output_true); + Output add1 = + ops::Add(root.WithOpName("and1"), sw_0.output_false, sw_1.output_false); + ops::Merge or2(root.WithOpName("or2"), {add0, add1}); + Output add3 = + ops::Add(root.WithOpName("and3"), or2.output, sw_0.output_false); + ops::Merge or4(root.WithOpName("or4"), {add3, sw_0.output_true}); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + EXPECT_EQ(predicate_map[ControlOutputFor(or4.output)], "#true"); +} + +TEST(DeadnessAnalysisTest, AndOrDistributive) { + // (A|B)&C == (A&C)|(B&C) Scope root = Scope::NewRootScope().ExitOnError(); ops::Switch sw_0 = CreateSwitch(root, "0"); @@ -408,7 +429,7 @@ TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add2.node())); + 
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add3.node())); } TEST(DeadnessAnalysisTest, Ternary) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index e0632ff7e48ccea99d469f62ec9d0a3fe8295024..da030b3bcc7aacae2306bec30f4b8927aa042d7c 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/public/session_options.h" @@ -78,7 +78,8 @@ void SortControlInputs(GraphDef* gdef) { namespace { bool AreAllParentsGuaranteedConst( - const Node& n, const gtl::FlatSet& runtime_const_nodes) { + const Node& n, + const absl::flat_hash_set& runtime_const_nodes) { if (n.type_string() == "GuaranteeConst") { // If the current node is itself a cast-to-const, no need // to look at the incoming edges. 
@@ -101,7 +102,7 @@ bool AreAllParentsGuaranteedConst( void MarkGuaranteedConstants( const Graph& graph, const std::vector>& src_arg_pairs) { - gtl::FlatSet guaranteed_const_nodes; + absl::flat_hash_set guaranteed_const_nodes; std::vector srcs; srcs.reserve(src_arg_pairs.size()); for (const auto& src_arg : src_arg_pairs) { @@ -748,6 +749,12 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { graph_->set_versions(graph_in->versions()); } + // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is + // determined. In case of hard placement, ensure all the encapsulated nodes + // have the same requested device, which in turn will be the requested device + // for the entire encapsulated subgraph. In case of soft placement, use a + // deterministic approach to fill in the requested device. Handle co-location + // constraints similarly if they exist. if (device_.empty()) { device_ = node->assigned_device_name().empty() ? node->requested_device() @@ -1102,6 +1109,9 @@ Status Encapsulator::Subgraph::BuildFunctionDef( function_def_name_ = name; FunctionDef fdef; + // Verify that the graph has well-formed control flow structure. + std::vector dummy; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &dummy)); TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); if (VLOG_IS_ON(1)) { @@ -1357,28 +1367,31 @@ void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames( Status Encapsulator::GetFunctionNameAttr( Node const* node, string* attr, string* outside_compilation_attr) const { - Status s = GetNodeAttr(node->attrs(), group_attribute_, attr); - if (s.code() == error::Code::NOT_FOUND) { - // Return empty attr if there's no group_attribute. 
- attr->clear(); - } else { - TF_RETURN_IF_ERROR(s); - } - bool has_group_attr = s.ok(); - s = GetNodeAttr(node->attrs(), outside_compilation_attribute_, - outside_compilation_attr); - if (s.code() == error::Code::NOT_FOUND) { - // Return empty attr if there's no outside_compilation attribute. - outside_compilation_attr->clear(); - } else { - TF_RETURN_IF_ERROR(s); - if (!has_group_attr) { - return errors::InvalidArgument( - "Node ", node->name(), " has ", outside_compilation_attribute_, - " attribute but no ", group_attribute_, " attribute."); + AttrSlice attrs = node->attrs(); + attr->clear(); + outside_compilation_attr->clear(); + bool found_group_attribute = false; + bool found_outside_compilation_attribute = false; + for (const auto& node_attr : attrs) { + if (node_attr.first == group_attribute_) { + TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string")); + *attr = node_attr.second.s(); + found_group_attribute = true; + } else if (node_attr.first == outside_compilation_attribute_) { + TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string")); + *outside_compilation_attr = node_attr.second.s(); + found_outside_compilation_attribute = true; } + if (found_group_attribute && found_outside_compilation_attribute) break; + } + + if (found_outside_compilation_attribute && !found_group_attribute) { + return errors::InvalidArgument( + "Node ", node->name(), " has ", outside_compilation_attribute_, + " attribute but no ", group_attribute_, " attribute."); + } else { + return Status::OK(); } - return Status::OK(); } bool IsInSubgraph(const string& func_id, const string& outside_compilation_id) { @@ -1521,9 +1534,6 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { for (auto& entry : subgraphs_) { Subgraph& subgraph = entry.second; FixupSourceAndSinkEdges(subgraph.GetGraph()); - // Verify that the graph has well-formed control flow structure. 
- std::vector dummy; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(subgraph.GetGraph(), &dummy)); } if (VLOG_IS_ON(1)) { diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..870a265f299969b670c564d2ce3d4847aa71fe6e --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_util.h" +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/jit/shape_inference.h" + +namespace tensorflow { + +namespace { + +// Returns string attribute value for the node if the attribute is present, +// otherwise returns empty optional value. +absl::optional GetStringAttr(const Node& n, const string& attr_name) { + auto attr = n.attrs().Find(attr_name); + if (!attr) { + return absl::nullopt; + } else { + return attr->s(); + } +} + +} // namespace + +const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes"; + +Status PerformStaticShapeInferenceBeforeEncapsulation( + Graph* g, const string& xla_computation_attr_name, + const string& outside_compilation_attr_name) { + // Find all outside compilation to XLA computation data edges. 
+ std::unordered_set outside_compilation_send_nodes; + for (auto e : g->edges()) { + if (e->IsControlEdge()) { + continue; + } + + auto src_computation = GetStringAttr(*e->src(), xla_computation_attr_name); + auto dst_computation = GetStringAttr(*e->dst(), xla_computation_attr_name); + if (!src_computation || !dst_computation || + *src_computation != *dst_computation) { + continue; + } + + auto src_outside_compilation = + GetStringAttr(*e->src(), outside_compilation_attr_name); + auto dst_outside_compilation = + GetStringAttr(*e->dst(), outside_compilation_attr_name); + if (src_outside_compilation && !dst_outside_compilation) { + outside_compilation_send_nodes.insert(e->src()); + } + } + + // Perform shape inference. + std::map arg_shapes; + GraphShapeInfo shape_info; + TF_RETURN_IF_ERROR( + InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info)); + + // Add attribute for output shapes. + for (Node* n : outside_compilation_send_nodes) { + auto iter = shape_info.find(n->name()); + if (iter == shape_info.end()) { + continue; + } + + std::vector output_shapes; + std::transform(iter->second.begin(), iter->second.end(), + std::back_inserter(output_shapes), + [](const InferredShape& inferred_shape) { + return inferred_shape.shape; + }); + n->AddAttr(kXlaInferredShapesAttrName, output_shapes); + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h new file mode 100644 index 0000000000000000000000000000000000000000..bc46521b98f43d6bfb1c115903d93dcd8006dc01 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file contains some utility functions for encapsulating XLA computation +// in host graph and encapsulating outside compilation in XLA computation. + +#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Attribute marking output tensor shapes inferred by XLA. Attribute value is +// a list of PartialTensorShape objects. +extern const char kXlaInferredShapesAttrName[]; + +// Infer output shapes for outside compilation nodes which have output data +// edges to XLA computation nodes. These shapes will be used later by XLA +// compiler as output shapes of the outside compilation's XlaHostCompute op. +// XLA computation nodes will be marked by attr `xla_computation_attr_name`; +// outside compilation nodes will be marked by both attr +// `xla_computation_attr_name` and `outside_compilation_attr_name`. +// +// Those outside compilation nodes will be marked with attribute +// `kXlaInferredShapesAttrName`. +// +// We have to perform shape inference before encapsulation because after +// encapsulation, some nodes will be encapsulated into function calls, and shape +// inference does not handle function calls at the moment. 
+Status PerformStaticShapeInferenceBeforeEncapsulation( + Graph* g, const string& xla_computation_attr_name, + const string& outside_compilation_attr_name); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..53bdf55ab2420f1bce2887c9214211fad3b0396b --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_util_test.cc @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_util.h" + +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) { + // Build the graph: + // "add" = "const_0" + "const_1" + // "identity" = "add" + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {2}); + Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {2}); + Output add = ops::Add(s.WithOpName("add"), const_0, const_1); + Output identity = ops::Identity(s.WithOpName("identity"), add); + Graph g(OpRegistry::Global()); + TF_CHECK_OK(s.ToGraph(&g)); + + // "add" node is outside compilation node, "identity" node is XLA node. + auto node_index = g.BuildNodeNameIndex(); + Node *add_node = node_index["add"], *identity_node = node_index["identity"]; + add_node->AddAttr("_xla", "cluster"); + add_node->AddAttr("_oc", "cluster"); + identity_node->AddAttr("_xla", "cluster"); + TF_CHECK_OK( + PerformStaticShapeInferenceBeforeEncapsulation(&g, "_xla", "_oc")); + + // Check that only "add" node now has _xla_inferred_shapes attr. 
+ std::vector nodes_with_inferred_shape; + for (Node* n : g.nodes()) { + if (HasNodeAttr(n->def(), kXlaInferredShapesAttrName)) { + nodes_with_inferred_shape.push_back(n); + } + } + EXPECT_EQ(nodes_with_inferred_shape.size(), 1); + EXPECT_EQ(nodes_with_inferred_shape[0], add_node); + std::vector output_shapes; + TF_CHECK_OK(GetNodeAttr(add_node->attrs(), kXlaInferredShapesAttrName, + &output_shapes)); + EXPECT_EQ(output_shapes.size(), 1); + TensorShapeProto shape_proto; + output_shapes[0].AsProto(&shape_proto); + EXPECT_EQ(shape_proto.dim_size(), 1); + EXPECT_EQ(shape_proto.dim(0).size(), 2); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 97ef8cd3cb3fba54259fc413e0a3d3e75a89c431..2ce6fa73fc448ca83fa392aa909cb385453eb8b6 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -15,13 +15,13 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -62,7 +62,7 @@ DataType EdgeType(const Edge* edge) { } // Adds the control inputs of `node` to `*deps`. 
-void AddControlInputs(const Node& node, gtl::FlatSet* deps) { +void AddControlInputs(const Node& node, absl::flat_hash_set* deps) { for (const Edge* edge : node.in_edges()) { if (edge->IsControlEdge()) { deps->insert(edge->src()); @@ -71,7 +71,7 @@ void AddControlInputs(const Node& node, gtl::FlatSet* deps) { } // Adds the control outputs of `node` to `*deps`. -void AddControlOutputs(const Node& node, gtl::FlatSet* deps) { +void AddControlOutputs(const Node& node, absl::flat_hash_set* deps) { for (const Edge* edge : node.out_edges()) { if (edge->IsControlEdge()) { deps->insert(edge->dst()); @@ -246,7 +246,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, // Data and control inputs to the new XlaLaunch node. std::vector> data_inputs(num_inputs); - gtl::FlatSet control_inputs; + absl::flat_hash_set control_inputs; DataTypeVector arg_types(num_args); AddControlInputs(*launch, &control_inputs); @@ -266,7 +266,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, // Outputs. const int num_outputs = launch->output_types().size(); - gtl::FlatSet control_outputs; + absl::flat_hash_set control_outputs; std::vector>> data_outputs(num_outputs); DataTypeVector output_types(num_outputs); @@ -297,7 +297,9 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, // Target the XLA CPU/GPU backends. 
VLOG(2) << "Replacing with XlaLaunch"; + VLOG(2) << "Device is " << launch->requested_device(); def.set_op("XlaLaunch"); + def.set_device(launch->requested_device()); AddNodeAttr("Tconstants", DataTypeVector{}, &def); AddNodeAttr("Targs", arg_types, &def); AddNodeAttr("Nresources", num_variables, &def); diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc index f643fb0cfe136caba42272d72f3972ec63a94bf3..192e1c7b32467d80cef6ff61a1c7078f8dea9dfb 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" #include "tensorflow/compiler/tf2xla/test_util.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -55,6 +55,7 @@ static std::unique_ptr MakeOuterGraph( .Input(u.node()->name(), 0, DT_RESOURCE) .Input(v.node()->name(), 0, DT_RESOURCE) .Input(w.node()->name(), 0, DT_RESOURCE) + .Device("/gpu:0") .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0") .Attr("_variable_start_index", 4) .Finalize(&def)); @@ -107,10 +108,11 @@ static std::unique_ptr MakeBodyGraph() { auto add_attrs = [](Node* node) { node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->set_requested_device("/gpu:0"); }; auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1); - + add_attrs(b_identity.node()); auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT); add_attrs(read_u.node()); auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, 
DT_FLOAT); @@ -215,6 +217,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) { auto add_attrs = [](Node* node) { node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->set_requested_device("/gpu:0"); }; auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b); @@ -253,7 +256,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) { TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); - std::unordered_map index = BuildNodeIndex(*graph); + std::unordered_map index = graph->BuildNodeNameIndex(); string function = index.at("launch0")->type_string(); // Tests the outer graph is as expected. @@ -288,7 +291,8 @@ TEST(EncapsulateXlaComputations, Encapsulate) { // function. Encapsulation should be deterministic to avoid recompilation. TF_ASSERT_OK( EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def)); - std::unordered_map index_copy = BuildNodeIndex(*graph_copy); + std::unordered_map index_copy = + graph_copy->BuildNodeNameIndex(); string function_copy = index_copy.at("launch0")->type_string(); EXPECT_EQ(function, function_copy); } @@ -317,8 +321,8 @@ TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) { NameAttrList function; function.set_name("launch0"); auto launch = ops::XlaLaunch( - scope.WithOpName("launch0"), std::initializer_list{}, - std::initializer_list{a, b, c, d}, + scope.WithOpName("launch0").WithDevice("/gpu:0"), + std::initializer_list{}, std::initializer_list{a, b, c, d}, std::initializer_list{u, v, w}, DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function); diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 0839f1cb3dafd9af533631c73a37a1df7172ac0b..26cb3af9d69ba1877c67853cde28d2477d394efc 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -26,6 +26,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/kernels:variable_ops", + 
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", ], alwayslink = 1, diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index a85006eb0378688dffd634c13a392b02e379c7f2..2268d9042860f6556cb69469ee52ad7cbbb81954 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/kernels/xla_ops.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -163,7 +164,7 @@ class XlaExecutableClosureStore { private: mutex mutex_; int64 key_counter_ GUARDED_BY(mutex_); - gtl::FlatMap closures_ GUARDED_BY(mutex_); + absl::flat_hash_map closures_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore); }; @@ -218,7 +219,7 @@ static Status BuildCompilationCache(OpKernelContext* ctx, static Status CompileToLocalExecutable( OpKernelContext* ctx, const NameAttrList& function, const XlaPlatformInfo& platform_info, absl::Span resources, - absl::Span constants, xla::LocalClient** client, + absl::Span constants, bool lazy, xla::LocalClient** client, std::map* variables, const XlaCompiler::CompilationResult** kernel, xla::LocalExecutable** executable) { @@ -276,7 +277,10 @@ static Status CompileToLocalExecutable( compile_options.always_return_tuple = false; return cache->Compile(options, function, constant_args, *variables, ctx, - kernel, executable, compile_options); + compile_options, + lazy ? 
XlaCompilationCache::CompileMode::kLazy + : XlaCompilationCache::CompileMode::kStrict, + kernel, executable); } void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { @@ -290,8 +294,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK( ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, - constants_, &client, &variables, &kernel, - &executable)); + constants_, /*lazy=*/false, &client, + &variables, &kernel, &executable)); se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; @@ -393,9 +397,12 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) resources_(ResourcesVector(ctx)), function_(FunctionAttr(ctx)) { OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("must_compile", &must_compile_)); } void XlaCompileOp::Compute(OpKernelContext* ctx) { + VLOG(3) << "XlaCompileOp " << def().name() + << (must_compile_ ? "(must-compile)" : ""); xla::LocalClient* client; const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; @@ -403,8 +410,24 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK( ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, - constants_, &client, &variables, &kernel, - &executable)); + constants_, /*lazy=*/!must_compile_, + &client, &variables, &kernel, &executable)); + + AllocatorAttributes host_alloc_attrs; + host_alloc_attrs.set_gpu_compatible(true); + host_alloc_attrs.set_on_host(true); + Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs); + + if (!executable) { + DCHECK(!must_compile_); + Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); + + Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({})); + compilation_successful.scalar()() = false; + ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({}))); + ctx->set_output(1, compilation_successful); + return; + } // Each execution of 
an XlaCompile op creates a new XlaExecutableClosure, even // if it didn't have to compile the cluster because of a compilation-cache @@ -414,13 +437,6 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure( client, executable, kernel, std::move(variables), constants_.size())); - Allocator* cpu_allocator = [&] { - AllocatorAttributes host_alloc_attrs; - host_alloc_attrs.set_gpu_compatible(true); - host_alloc_attrs.set_on_host(true); - return ctx->device()->GetAllocator(host_alloc_attrs); - }(); - Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); compilation_key.flat()(0) = key; @@ -436,6 +452,7 @@ XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) { } void XlaRunOp::Compute(OpKernelContext* ctx) { + VLOG(3) << "XlaRunOp " << def().name(); Tensor key_tensor = ctx->input(ctx->num_inputs() - 1); const XlaExecutableClosureStore::KeyT& key = key_tensor.flat()(0); @@ -490,6 +507,8 @@ REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp); REGISTER_KERNEL_BUILDER(Name("_XlaCompile") .Device(DEVICE_GPU) .HostMemory("constants") + .HostMemory("key") + .HostMemory("compilation_successful") .HostMemory("resources"), XlaCompileOp); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index 489d26eb30a66646158f39ea3fc6f55759c7f88e..ac90837e0d90943b93e2cdb01a30fa0837ba94df 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -151,6 +151,8 @@ class XlaCompileOp : public OpKernel { NameAttrList function_; XlaPlatformInfo platform_info_; + + bool must_compile_; }; class XlaRunOp : public OpKernel { diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 07c5b2318851ed506711b9ee00c66fe680a3afd8..d8fe4026f51d8aa4b027aeedf0795ad30e28d986 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ 
b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -39,3 +39,15 @@ cc_library( "//tensorflow/core:lib", ], ) + +cc_library( + name = "build_xla_ops_pass_flags", + srcs = ["build_xla_ops_pass_flags.cc"], + hdrs = ["build_xla_ops_pass_flags.h"], + deps = + [ + "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) diff --git a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc new file mode 100644 index 0000000000000000000000000000000000000000..58157d2b9800a2e8269533607c2ea688ff4e7766 --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include // NOLINT + +#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { +namespace { + +BuildXlaOpsPassFlags* flags; +std::vector* flag_list; +std::once_flag flags_init; + +void AllocateAndParseFlags() { + flags = new BuildXlaOpsPassFlags; + flags->tf_xla_enable_lazy_compilation = false; + flag_list = new std::vector({ + Flag("tf_xla_enable_lazy_compilation", + &flags->tf_xla_enable_lazy_compilation, ""), + }); + xla::legacy_flags::ParseFlagsFromEnv(*flag_list); +} + +} // namespace + +const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return *flags; +} +} // namespace legacy_flags +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h new file mode 100644 index 0000000000000000000000000000000000000000..539314cbf72d38ed973b8a526aa6424b19ef344d --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ +#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ + +namespace tensorflow { +namespace legacy_flags { + +// Flags for the build_xla_ops pass. +struct BuildXlaOpsPassFlags { + // Enables lazy compilation for TF/XLA (only when auto-clustering) if true. + // Defaults to false. + bool tf_xla_enable_lazy_compilation; +}; + +// Parses the flags in BuildXlaOpsPassFlags from the TF_XLA_FLAGS environment +// variable and returns a reference to the parsed copy. Parses TF_XLA_FLAGS +// only the first time this routine is called. +const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags(); + +} // namespace legacy_flags +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 133d9823609efa688d8a8f7a066ccbfefc75c15b..4f0c370e65159c89c91ea58733f20f852d9acc99 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" @@ -42,7 +43,6 @@ limitations under the License. 
#include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -371,7 +371,7 @@ bool IsXlaFusable(const NodeDef& node) { Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function& is_compilable_fn, - OrderedNodeSet* candidates, gtl::FlatSet* isolated_nodes) { + OrderedNodeSet* candidates, absl::flat_hash_set* isolated_nodes) { OptimizerOptions opts; std::unique_ptr pflr( new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION, @@ -849,7 +849,7 @@ Status MarkForCompilationPass::RunImpl( Graph* graph = options.graph->get(); OrderedNodeSet compilation_candidates; - gtl::FlatSet isolated_nodes; + absl::flat_hash_set isolated_nodes; TF_RETURN_IF_ERROR(FindCompilationCandidates( *graph, options.flib_def, (options.session_options != nullptr) ? options.session_options->env diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 4f9145b4799d5fbaaae2bafd47edec7fa6e463a3..2a80c745e3fcebf97bcccb03551feb3d6fb9f831 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" @@ -61,10 +62,10 @@ std::unordered_map GetClusters(const Graph& graph) { return ids; } -gtl::FlatMap> GetClusterSets( +absl::flat_hash_map> GetClusterSets( const Graph& g, std::vector* cluster_names = nullptr) { CHECK(cluster_names == nullptr || cluster_names->empty()); - gtl::FlatMap> cluster_sets; + absl::flat_hash_map> cluster_sets; for (const auto& p : GetClusters(g)) { cluster_sets[p.second].push_back(p.first); } @@ -566,7 +567,7 @@ TEST(XlaCompilationTest, ResourcesClusteringAllowed) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - gtl::FlatMap> cluster_sets = + absl::flat_hash_map> cluster_sets = GetClusterSets(*graph); ASSERT_EQ(cluster_sets.size(), 1); std::vector expected_clustered_nodes = {"AssignmentW", "ReadR", @@ -586,7 +587,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - gtl::FlatMap> cluster_sets = + absl::flat_hash_map> cluster_sets = GetClusterSets(*graph); ASSERT_EQ(cluster_sets.size(), 1); std::vector expected_clustered_nodes = {"AssignmentW", @@ -616,7 +617,7 @@ TEST(XlaCompilationTest, ChainOfOps) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::vector cluster_names; - gtl::FlatMap> cluster_sets = + absl::flat_hash_map> cluster_sets = GetClusterSets(*graph, &cluster_names); ASSERT_EQ(cluster_sets.size(), 2); diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc index 
d8ace628e6b76e011ecddd4d526efc4db9c9237e..a09a6eb1553cb4bcf5587a7602097a40b64cfcdf 100644 --- a/tensorflow/compiler/jit/node_matchers.cc +++ b/tensorflow/compiler/jit/node_matchers.cc @@ -19,7 +19,10 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.pb.h" namespace tensorflow { @@ -28,6 +31,7 @@ namespace matchers { namespace { using impl::NodeMatcherProperties; +using impl::OutEdge; string IndentAllButFirstLine(absl::string_view text) { std::vector lines = absl::StrSplit(text, '\n'); @@ -99,8 +103,6 @@ bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor, } } -using Input = std::pair; - struct NodeMatcher : public ::testing::MatcherInterface { bool MatchAndExplain( const Node* node, @@ -191,6 +193,29 @@ struct NodeMatcher : public ::testing::MatcherInterface { } return false; } + + const AttrValueMap attr_value_map = node->def().attr(); + for (const auto& attr_kv_pair : attrs) { + auto it = attr_value_map.find(attr_kv_pair.first); + if (it == attr_value_map.end()) { + if (listener->IsInterested()) { + *listener << "did not find attribute named \"" << attr_kv_pair.first + << "\" in node"; + } + return false; + } + if (!AreAttrValuesEqual(it->second, attr_kv_pair.second)) { + if (listener->IsInterested()) { + *listener << "attribute named " << attr_kv_pair.first + << " does not match value; expected: \"" + << SummarizeAttrValue(attr_kv_pair.second) + << "\", found: \"" << SummarizeAttrValue(it->second) + << "\""; + } + return false; + } + } + return true; } @@ -232,7 +257,7 @@ struct NodeMatcher : public ::testing::MatcherInterface { *os << "matching " << ss.str(); } else { int edge_idx = 0; - for (const ::testing::Matcher& 
matcher : (*input_matchers)) { + for (const ::testing::Matcher& matcher : (*input_matchers)) { *os << "\n [" << edge_idx << "] matching ("; ::std::stringstream ss; matcher.DescribeTo(&ss); @@ -250,6 +275,19 @@ struct NodeMatcher : public ::testing::MatcherInterface { control_dep_set->DescribeTo(os); } + if (!attrs.empty()) { + printed_something = true; + std::vector attrs_str; + absl::c_transform(attrs, std::back_inserter(attrs_str), + [](const std::pair& attr_kv_pair) { + return absl::StrCat( + attr_kv_pair.first, "->", + SummarizeAttrValue(attr_kv_pair.second)); + }); + *os << " and attr values matching [" << absl::StrJoin(attrs_str, ", ") + << "]"; + } + if (!printed_something) { *os << "is any node"; } @@ -266,7 +304,7 @@ struct NodeMatcher : public ::testing::MatcherInterface { } ::testing::StringMatchResultListener inner_listener; - Input input = {edge->src(), edge->src_output()}; + OutEdge input = {edge->src(), edge->src_output()}; if ((*input_matchers)[input_idx].MatchAndExplain(input, &inner_listener)) { return true; } @@ -286,22 +324,24 @@ struct NodeMatcher : public ::testing::MatcherInterface { absl::optional name; absl::optional assigned_device; absl::optional constant_value; - absl::optional>> input_matchers; + absl::optional>> input_matchers; absl::optional<::testing::Matcher>> control_dep_set; + std::map attrs; }; // Matches a dst and dst_output on an input edge. Today we only use this with // dst_output=0 but we will eventually need to support multi-output operations. 
-class InputMatcher : public ::testing::MatcherInterface { +class OutEdgeMatcher : public ::testing::MatcherInterface { public: - InputMatcher(::testing::Matcher src_matcher, int src_output) - : src_matcher_(std::move(src_matcher)), src_output_(src_output) {} + OutEdgeMatcher(::testing::Matcher src_matcher, int src_oidx) + : src_matcher_(std::move(src_matcher)), src_oidx_(src_oidx) {} bool MatchAndExplain( - Input input, ::testing::MatchResultListener* listener) const override { + OutEdge out_edge, + ::testing::MatchResultListener* listener) const override { ::testing::StringMatchResultListener inner_listener; - if (!src_matcher_.MatchAndExplain(input.first, &inner_listener)) { + if (!src_matcher_.MatchAndExplain(out_edge.first, &inner_listener)) { if (listener->IsInterested()) { *listener << "\nsource does not match expected "; src_matcher_.DescribeTo(listener->stream()); @@ -312,10 +352,10 @@ class InputMatcher : public ::testing::MatcherInterface { } return false; } - if (input.second != src_output_) { + if (out_edge.second != src_oidx_) { if (listener->IsInterested()) { - *listener << "\nexpected output slot to be " << src_output_ - << " but found " << input.second; + *listener << "\nexpected output slot to be " << src_oidx_ + << " but found " << out_edge.second; } return false; } @@ -324,31 +364,21 @@ class InputMatcher : public ::testing::MatcherInterface { } void DescribeTo(::std::ostream* os) const override { - if (src_output_) { - *os << "output slot: " << src_output_ << ", source: ("; + if (src_oidx_) { + *os << "output slot: " << src_oidx_ << ", source: ("; } src_matcher_.DescribeTo(os); - if (src_output_) { + if (src_oidx_) { *os << ")"; } } private: ::testing::Matcher src_matcher_; - int src_output_; + int src_oidx_; }; - -std::vector<::testing::Matcher> NodeMatchersToInputMatchers( - absl::Span> node_matchers) { - std::vector<::testing::Matcher> result; - absl::c_transform(node_matchers, std::back_inserter(result), - [](::testing::Matcher n) { - 
return ::testing::MakeMatcher(new InputMatcher(n, 0)); - }); - return result; -} } // namespace ::testing::Matcher impl::NodeWith( @@ -375,10 +405,9 @@ std::vector<::testing::Matcher> NodeMatchersToInputMatchers( matcher->assigned_device = prop.assigned_device(); } - if (prop.input_nodes()) { + if (prop.inputs()) { DCHECK(!matcher->input_matchers); - matcher->input_matchers = - NodeMatchersToInputMatchers(*prop.input_nodes()); + matcher->input_matchers = *prop.inputs(); } if (prop.control_deps()) { @@ -386,6 +415,11 @@ std::vector<::testing::Matcher> NodeMatchersToInputMatchers( matcher->control_dep_set = ::testing::UnorderedElementsAreArray(*prop.control_deps()); } + + if (prop.attr()) { + auto insert_result = matcher->attrs.insert(*prop.attr()); + DCHECK(insert_result.second); + } } return ::testing::MakeMatcher(matcher); @@ -412,12 +446,12 @@ impl::NodeMatcherProperties AssignedDevice(string assigned_device) { } impl::NodeMatcherProperties impl::Inputs( - absl::Span> inputs) { - std::vector<::testing::Matcher> inputs_vector; + absl::Span> inputs) { + std::vector<::testing::Matcher> inputs_vector; absl::c_copy(inputs, std::back_inserter(inputs_vector)); impl::NodeMatcherProperties props; - props.set_input_nodes(std::move(inputs_vector)); + props.set_inputs(std::move(inputs_vector)); return props; } @@ -431,6 +465,19 @@ impl::NodeMatcherProperties impl::CtrlDeps( return props; } +std::pair impl::AttrLiteralHelper( + const std::pair& bool_attr) { + AttrValue attr_value; + attr_value.set_b(bool_attr.second); + return {bool_attr.first, attr_value}; +} + +impl::NodeMatcherProperties impl::Attr(std::pair attr) { + impl::NodeMatcherProperties props; + props.set_attr(std::move(attr)); + return props; +} + NodeMatcherProperties ConstantValue( const ::tensorflow::Input::Initializer& val) { TF_CHECK_OK(val.status); @@ -443,6 +490,10 @@ NodeMatcherProperties ConstantValue( const ::tensorflow::Input::Initializer& val) { return NodeWith(ConstantValue(val)); } 
+::testing::Matcher Out( + int oidx, ::testing::Matcher node_matcher) { + return ::testing::MakeMatcher(new OutEdgeMatcher(node_matcher, oidx)); +} } // namespace matchers Node* FindNodeByName(Graph* g, absl::string_view name) { @@ -455,4 +506,7 @@ Node* FindNodeByName(Graph* g, absl::string_view name) { return nullptr; } } // namespace testing + +void PrintTo(const Node* n, ::std::ostream* os) { *os << SummarizeNode(*n); } +void PrintTo(Node* n, ::std::ostream* os) { *os << SummarizeNode(*n); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/node_matchers.h b/tensorflow/compiler/jit/node_matchers.h index 0437a7e95c1eb3bdcdbe24a440dd90a5943c0894..35c2f5fd7b533d0e8716dc6c70c21afe9a32c9c8 100644 --- a/tensorflow/compiler/jit/node_matchers.h +++ b/tensorflow/compiler/jit/node_matchers.h @@ -19,7 +19,7 @@ limitations under the License. // // tensorflow::Node* node = ...; // EXPECT_THAT(node, NodeWith(Name("name"), Op("op"), -// Inputs(NodeWith(Name("input"))))) +// Inputs(Out(3, NodeWith(Name("input")))))) // // Matchable node properties (the expressions that go inside NodeWith(...)) // are: @@ -32,7 +32,8 @@ limitations under the License. // - AssignedDevice(string): matches the assigned device exactly. // // - Inputs(): matches the list of non-control inputs to the node -// exactly (i.e. does not match a suffix or a prefix). +// exactly (i.e. does not match a suffix or a prefix) where each element +// matches an output of a node (see Out(idx, node) below). // // - CtrlDeps(): matches the list of control dependences on the // node exactly but in any order. @@ -40,10 +41,16 @@ limitations under the License. // - ConstantValue(tensorflow::Input::Initializer init): matches a Const node // with the constant value `init`. Implies Op("Const"). // -// Node properties may not be repeated in a single NodeWith(...) matcher. -// E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail. 
Since ConstantValue -// implies Op("Const"), a single NodeWith matcher can't have both -// ConstantValue(...) and Op(...). +// - Attr(name, value): Matches a single attribute with name `name` and value +// `value`. Right now only boolean values are supported. +// +// Overlapping node properties may not be repeated in a single NodeWith(...) +// matcher. E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail. Since +// ConstantValue implies Op("Const"), a single NodeWith matcher can't have both +// ConstantValue(...) and Op(...). Multiple Attr() values can be combined as +// long as the attribute names are different. +// +// Out(idx, node) matches the `idx`'th output of a node that matches `node`. #ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ #define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ @@ -66,6 +73,8 @@ namespace matchers { namespace impl { +using OutEdge = std::pair; + // ----------------------------------------------------------------------------- // Implementation details. @@ -74,6 +83,8 @@ namespace impl { class NodeMatcherProperties { public: using NodeSeqMatcher = std::vector<::testing::Matcher>; + using InputSeqMatcher = std::vector<::testing::Matcher>; + using AttrKeyValuePair = std::pair; const absl::optional& name() const { return name_; } const absl::optional& op() const { return op_; } @@ -83,12 +94,13 @@ class NodeMatcherProperties { const absl::optional& constant_value() const { return constant_value_; } - const absl::optional& input_nodes() const { - return input_nodes_; + const absl::optional& inputs() const { + return input_matchers_; } const absl::optional& control_deps() const { return control_deps_; } + const absl::optional& attr() const { return attr_; } void set_name(string name) { DCHECK(IsEmpty()); @@ -111,9 +123,9 @@ class NodeMatcherProperties { op_ = "Const"; } - void set_input_nodes(NodeSeqMatcher input_nodes) { + void set_inputs(InputSeqMatcher inputs) { DCHECK(IsEmpty()); - input_nodes_ = std::move(input_nodes); + input_matchers_ = 
std::move(inputs); } void set_control_deps(NodeSeqMatcher control_deps) { @@ -121,9 +133,14 @@ class NodeMatcherProperties { control_deps_ = std::move(control_deps); } + void set_attr(AttrKeyValuePair attr) { + DCHECK(IsEmpty()); + attr_ = std::move(attr); + } + bool IsEmpty() const { - return !name().has_value() && !op().has_value() && - !input_nodes().has_value() && !control_deps().has_value(); + return !name().has_value() && !op().has_value() && !inputs().has_value() && + !control_deps().has_value() && !attr().has_value(); } private: @@ -131,18 +148,24 @@ class NodeMatcherProperties { absl::optional op_; absl::optional assigned_device_; absl::optional constant_value_; - absl::optional input_nodes_; + absl::optional input_matchers_; absl::optional control_deps_; + absl::optional attr_; }; ::testing::Matcher NodeWith( absl::Span props); impl::NodeMatcherProperties Inputs( - absl::Span> inputs); + absl::Span> inputs); impl::NodeMatcherProperties CtrlDeps( absl::Span> control_deps); + +impl::NodeMatcherProperties Attr(std::pair attrs); + +std::pair AttrLiteralHelper( + const std::pair& bool_attr); } // namespace impl // ----------------------------------------------------------------------------- @@ -157,6 +180,13 @@ impl::NodeMatcherProperties Op(string op); // Matches a node with assigned device `assigned_device`. impl::NodeMatcherProperties AssignedDevice(string assigned_device); +// Matches a node with a boolean typed attribute named `name` and with value +// `value`. +template +impl::NodeMatcherProperties Attr(const string& name, ValueTy value) { + return impl::Attr({impl::AttrLiteralHelper({name, value})}); +} + // Matches a node with inputs `inputs`. // // `inputs` are ordered; `inputs`[i] must match input i. @@ -165,6 +195,15 @@ impl::NodeMatcherProperties Inputs(Ts... inputs) { return impl::Inputs({inputs...}); } +// Matches the `idx`'th output of a node that matches `node`. 
+::testing::Matcher Out(int oidx, + ::testing::Matcher node); + +// Matches the first output of a node that matches `node`. +::testing::Matcher Out(::testing::Matcher node) { + return Out(0, node); +} + // Matches a node with control dependences `control_deps`. // // `control_deps` are unordered and will match the control deps of a node in any @@ -192,6 +231,9 @@ template // If `g` has a node named `name` returns it, otherwise returns null. Node* FindNodeByName(Graph* g, absl::string_view name); } // namespace testing + +void PrintTo(const Node* n, ::std::ostream* os); +void PrintTo(Node* n, ::std::ostream* os); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ diff --git a/tensorflow/compiler/jit/node_matchers_test.cc b/tensorflow/compiler/jit/node_matchers_test.cc index 93a8994307b38ac240c22d0a18268638ac7620ae..c3f0dfece85573d71dbfa21eba5af70b674fe71e 100644 --- a/tensorflow/compiler/jit/node_matchers_test.cc +++ b/tensorflow/compiler/jit/node_matchers_test.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/control_flow_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/math_ops.h" namespace tensorflow { @@ -27,12 +29,14 @@ namespace { using ::testing::_; using testing::matchers::AssignedDevice; +using testing::matchers::Attr; using testing::matchers::ConstantValue; using testing::matchers::CtrlDeps; using testing::matchers::Inputs; using testing::matchers::Name; using testing::matchers::NodeWith; using testing::matchers::Op; +using testing::matchers::Out; template string Explain(const T& t, const M& m) { @@ -61,7 +65,7 @@ TEST(NodeMatchers, CheckAgainstConstant) { "\nexpected op Add but found Placeholder"); EXPECT_EQ(Explain(placeholder.node(), NodeWith(Name("add"))), "\nexpected name add but found placeholder"); - EXPECT_EQ(Explain(placeholder.node(), NodeWith(Inputs(NodeWith()))), + EXPECT_EQ(Explain(placeholder.node(), NodeWith(Inputs(Out(NodeWith())))), "\nexpected 1 inputs but node has 0"); } @@ -74,18 +78,19 @@ TEST(NodeMatchers, CheckAgainstBinary) { ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT); Output add = ops::Add(root.WithOpName("add"), placeholder_a, placeholder_b); - EXPECT_THAT(add.node(), NodeWith(Op("Add"), Name("add"), - Inputs(NodeWith(Name("placeholder_a")), - NodeWith(Name("placeholder_b"))))); + EXPECT_THAT(add.node(), + NodeWith(Op("Add"), Name("add"), + Inputs(Out(NodeWith(Name("placeholder_a"))), + Out(NodeWith(Name("placeholder_b")))))); EXPECT_EQ(Explain(add.node(), NodeWith(Inputs())), "\nexpected 0 inputs but node has 2"); EXPECT_EQ( - Explain(add.node(), NodeWith(Inputs(NodeWith(Name("blah")), _))), + Explain(add.node(), NodeWith(Inputs(Out(NodeWith(Name("blah"))), _))), "\ninput 0 does not match expected:\nname: blah, \nsource does not match " "expected name: blah\n\t\nexpected name blah but found placeholder_a"); EXPECT_EQ( - 
Explain(add.node(), NodeWith(Inputs(_, NodeWith(Name("blah"))))), + Explain(add.node(), NodeWith(Inputs(_, Out(NodeWith(Name("blah")))))), "\ninput 1 does not match expected:\nname: blah, \nsource does not match " "expected name: blah\n\t\nexpected name blah but found placeholder_b"); } @@ -174,6 +179,36 @@ TEST(NodeMatchers, AssignedDevice) { "/job:localhost/replica:0/task:0/device:CPU:0 but found \"\""); } +TEST(NodeMatchers, OutputIndices) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output pred = ops::Placeholder(root.WithOpName("pred"), DT_BOOL); + + Output data = ops::Placeholder(root.WithOpName("data"), DT_FLOAT); + ops::Switch sw(root.WithOpName("switch"), data, pred); + Output add = ops::Add(root.WithOpName("add"), sw.output_true, + ops::Placeholder(root.WithOpName("addend"), DT_FLOAT)); + + EXPECT_THAT(add.node(), NodeWith(Inputs(Out(1, NodeWith(Op("Switch"))), _))); + EXPECT_EQ( + Explain(add.node(), NodeWith(Inputs(Out(0, NodeWith(Op("Switch"))), _))), + "\ninput 0 does not match expected:\nop: Switch, \nexpected output slot " + "to be 0 but found 1"); +} + +TEST(NodeMatchers, Attrs) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output enter = ops::internal::Enter( + root.WithOpName("enter"), + ops::Placeholder(root.WithOpName("data"), DT_FLOAT), "frame_name", + ops::internal::Enter::Attrs{}.IsConstant(true)); + EXPECT_THAT(enter.node(), NodeWith(Attr("is_constant", true))); + EXPECT_EQ(Explain(enter.node(), NodeWith(Attr("is_constant", false))), + "attribute named is_constant does not match value; expected: " + "\"false\", found: \"true\""); + EXPECT_EQ(Explain(enter.node(), NodeWith(Attr("missing_attr", false))), + "did not find attribute named \"missing_attr\" in node"); +} + } // namespace } // namespace testing } // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index bcd1a29b1ff789b5674a21ff66cc6d23a809afc5..95d12e95fd9a0d1cca513ee74a0651ea69eba89e 100644 --- 
a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -54,6 +54,7 @@ REGISTER_OP("XlaClusterOutput") REGISTER_OP("_XlaCompile") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") + .Attr("must_compile: bool") .Input("args: Targs") .Attr("Targs: list(type) >= 0") .Input("resources: Nresources * resource") @@ -71,8 +72,12 @@ that _XlaRun can use to look up the LocalExecutable and execute it. key: A key that can be used to look up the local executable compiled by the node and associated metadata. -compilation_successful: True iff the compilation was successful. Always true -for now. +compilation_successful: If the `must_compile` attr is false the _XlaCompile op + can decide not to compile the clusters based on some profitability + heuristics. In that case `compilation_successful` is false if _XlaCompile + chose not to compile the cluster. If the `must_compile` attr is true then + _XlaCompile always attempts to compile the cluster and + `compilation_successful` is always true. )"); REGISTER_OP("_XlaRun") diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 10fc9e85d927ffe2416d6d9e6dfd24b286fbf1a0..5b9610322336acbcede0bef0538043b8ff917c16 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -15,17 +15,19 @@ limitations under the License. 
#include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { namespace { -Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, +Status FindNodesToDecluster(const Graph& graph, + absl::flat_hash_set* result, absl::Span post_order) { // Find nodes that have at least one user outside their cluster that expects // hostmem output. These nodes should be cloned to outside the cluster to @@ -171,7 +173,7 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), /*edge_filter=*/NotBackedge); - gtl::FlatSet nodes_to_partially_decluster; + absl::flat_hash_set nodes_to_partially_decluster; TF_RETURN_IF_ERROR( FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); @@ -205,18 +207,27 @@ bool IsIntraClusterEdge(const Edge& edge) { return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name; } -Status MustCompileNode(const Node* n, bool* result) { +bool IsMustCompileDevice(const DeviceType& device_type) { + const XlaOpRegistry::DeviceRegistration* registration; + if (XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) { + return registration->requires_compilation; + } + + return false; +} + +Status MustCompileNode(const Node* n, bool* must_compile) { DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceToDeviceType(n->assigned_device_name(), &device_type)); - const XlaOpRegistry::DeviceRegistration* registration; - if 
(!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) { - *result = false; - } else { - *result = registration->requires_compilation; + if (IsMustCompileDevice(device_type)) { + *must_compile = true; + return Status::OK(); } + // We must compile `n` if it does not have a TensorFlow kernel. + *must_compile = !FindKernelDef(device_type, n->def(), nullptr, nullptr).ok(); return Status::OK(); } diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 0feb73a89e7050e8c413e5a733da1d87775b0ba3..74d5ef57184197ad6e9e5048722e84863756a3f5 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" @@ -405,5 +406,36 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) { } } +TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output dynamic_slice_operand = + ops::Placeholder(s.WithOpName("dynamic_slice_operand"), DT_INT32, + ops::Placeholder::Attrs{}); + Output dynamic_slice_begin = ops::Placeholder( + s.WithOpName("dynamic_slice_begin"), DT_INT32, ops::Placeholder::Attrs{}); + Output dynamic_slice_size = ops::Placeholder( + s.WithOpName("dynamic_slice_size"), DT_INT32, ops::Placeholder::Attrs{}); + Output dynamic_slice = + ops::XlaDynamicSlice(s.WithOpName("dynamic_slice"), dynamic_slice_operand, + dynamic_slice_begin, dynamic_slice_size); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, 
ops::Placeholder::Attrs{}); + Output reshape = + ops::Reshape(s.WithOpName("reshape"), reshape_input, dynamic_slice); + + AddToCluster({dynamic_slice.node(), reshape.node()}, "cluster_0"); + + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + + Node* n = FindNodeByName(*graph, "dynamic_slice"); + ASSERT_NE(n, nullptr); + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 56e35c0059124015266ffabdf583c8724c8e0908..e039d46ec863920eb7deb5bc20525fdab866415c 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -82,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" @@ -89,8 +90,6 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/util/ptr_util.h" @@ -177,7 +176,7 @@ string ResourceOpToString(const ResourceOp& resource_op) { // point. 
class ResourceOpSet { private: - using Impl = gtl::FlatSet; + using Impl = absl::flat_hash_set; public: ResourceOpSet() = default; diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc new file mode 100644 index 0000000000000000000000000000000000000000..80c691fe490c1092315708a2da754d367d585300 --- /dev/null +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -0,0 +1,174 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/shape_inference.h" + +#include "tensorflow/compiler/jit/shape_inference_helpers.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +namespace { + +// Converts a shape inference handle to a PartialTensorShape. 
+Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context, + const shape_inference::ShapeHandle& handle, + PartialTensorShape* shape) { + // The default is already unknown + if (!context->RankKnown(handle)) return Status::OK(); + + std::vector dims(context->Rank(handle)); + for (int32 i = 0; i < dims.size(); ++i) { + dims[i] = context->Value(context->Dim(handle, i)); + } + return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape); +} + +Status PropagateShapes(const Graph& graph, + const std::map& arg_shapes, + ShapeRefiner* shape_refiner) { + // Visits the nodes in topological order (reverse post-order), inferring + // shapes. + // TODO(phawkins): handle cyclic graphs. + std::vector order; + GetReversePostOrder(graph, &order); + + for (Node* n : order) { + // Ignore the status returned by the shape_refiner. We want the best effort + // shapes, even if no shape function is registered for a node. + Status status = shape_refiner->AddNode(n); + if (!status.ok()) { + VLOG(1) << "Shape inference failed for node: " << status; + } + + if (n->type_string() == "_Arg") { + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + auto it = arg_shapes.find(index); + if (it != arg_shapes.end()) { + const InferredShape& arg_shape = it->second; + shape_inference::InferenceContext* context = + shape_refiner->GetContext(n); + + if (arg_shape.handle_type != DT_INVALID) { + shape_inference::ShapeHandle handle; + TF_RETURN_IF_ERROR(context->MakeShapeFromPartialTensorShape( + arg_shape.handle_shape, &handle)); + + // Sets the shape and type of the variable's value. 
+ context->set_output_handle_shapes_and_types( + 0, std::vector{ + {handle, arg_shape.handle_type}}); + } + + shape_inference::ShapeHandle handle; + TF_RETURN_IF_ERROR( + context->MakeShapeFromPartialTensorShape(arg_shape.shape, &handle)); + TF_RETURN_IF_ERROR(shape_refiner->SetShape(n, 0, handle)); + } + } + } + return Status::OK(); +} + +// Store the shapes of the output tensors in a map +Status StoreOutputShapes(const Graph& graph, const ShapeRefiner& shape_refiner, + GraphShapeInfo* shape_info) { + for (const Node* node : graph.nodes()) { + shape_inference::InferenceContext* context = shape_refiner.GetContext(node); + if (!context) continue; + + auto& outputs = (*shape_info)[node->name()]; + outputs.resize(context->num_outputs()); + for (int i = 0; i < context->num_outputs(); ++i) { + auto& output = outputs[i]; + TF_RETURN_IF_ERROR( + ShapeHandleToTensorShape(context, context->output(i), &output.shape)); + + const auto* handle_shapes_and_types = + context->output_handle_shapes_and_types(i); + if (handle_shapes_and_types != nullptr) { + if (handle_shapes_and_types->size() == 1) { + TF_RETURN_IF_ERROR(ShapeHandleToTensorShape( + context, (*handle_shapes_and_types)[0].shape, + &output.handle_shape)); + output.handle_type = (*handle_shapes_and_types)[0].dtype; + } else { + // otherwise, it may be a resource like a Queue, which can have + // multiple shapes and types represented by a single handle. 
+ } + } + VLOG(4) << node->name() << " output " << i << " shape" + << output.shape.DebugString() << " handle_type " + << DataTypeString(output.handle_type) << " handle_shape " + << output.handle_shape.DebugString(); + } + } + return Status::OK(); +} + +} // namespace + +Status InferShapes(Graph* graph, const std::map& arg_shapes, + const tensorflow::FunctionLibraryDefinition* fnlib_def, + GraphShapeInfo* shape_info) { + ShapeRefiner shape_refiner(graph->versions(), graph->op_registry()); + shape_refiner.set_require_shape_inference_fns(false); + // TODO(dlibenzi): Verify if it is worth trying to infer shapes within + // functions. Some functions can be called at multiple locations with + // different shapes, which will trigger a shape inference based on the + // arguments passed at the first call. + // shape_refiner.set_function_library_for_shape_inference(fnlib_def); + + // ShapeRefiner requires that all inputs of a node are present when + // ShapeRefiner::AddNode is called. To get at least some shape information in + // loops, we temporarily remove loop backedges and add them back again after + // the shape inference is complete. + BackEdgeHelper back_edge; + TF_RETURN_IF_ERROR(back_edge.Remove(graph)); + TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes, &shape_refiner)); + TF_RETURN_IF_ERROR(back_edge.Replace()); + + // Currently information does not flow "backward" from consumers to producers + // in the shape inference, but we consume the shapes in a second pass in case + // backward information flow is added in the future. 
+ return StoreOutputShapes(*graph, shape_refiner, shape_info); +} + +xla::StatusOr MergeInferredShapes(const InferredShape& a, + const InferredShape& b) { + InferredShape result; + TF_RETURN_IF_ERROR(a.shape.MergeWith(b.shape, &result.shape)); + + if (a.handle_type == DT_INVALID) { + result.handle_type = b.handle_type; + } else if (b.handle_type == DT_INVALID) { + result.handle_type = a.handle_type; + } else if (a.handle_type == b.handle_type) { + result.handle_type = a.handle_type; + } else { + return errors::InvalidArgument( + "Mismatched resource types: ", DataTypeString(a.handle_type), " vs. ", + DataTypeString(b.handle_type)); + } + TF_RETURN_IF_ERROR( + a.handle_shape.MergeWith(b.handle_shape, &result.handle_shape)); + return result; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/shape_inference.h b/tensorflow/compiler/jit/shape_inference.h new file mode 100644 index 0000000000000000000000000000000000000000..8668dbca55c2cf84729d81086bde45757e54f8ab --- /dev/null +++ b/tensorflow/compiler/jit/shape_inference.h @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_ +#define TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_ + +#include +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +struct InferredShape { + // Shape of the argument tensor. + PartialTensorShape shape; + + // If the argument is a resource variable, the type and shape of the + // variable's value. + DataType handle_type = DT_INVALID; + PartialTensorShape handle_shape; +}; +typedef std::unordered_map> GraphShapeInfo; + +// Infer shapes for all Tensors in a graph, and save them in a map. The vector +// for a Node contains the information about each of its outputs. +// TODO(phawkins): this code does not infer accurate shapes for cyclic graphs. +Status InferShapes(Graph* graph, const std::map& arg_shapes, + const tensorflow::FunctionLibraryDefinition* fnlib_def, + GraphShapeInfo* shape_info); + +// Merges two InferredShapes. Return an error if the two shapes cannot be +// merged. +xla::StatusOr MergeInferredShapes(const InferredShape& a, + const InferredShape& b); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_ diff --git a/tensorflow/compiler/jit/shape_inference_test.cc b/tensorflow/compiler/jit/shape_inference_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9268172b1c4a4a717b608a52041219d54383a3ff --- /dev/null +++ b/tensorflow/compiler/jit/shape_inference_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests for ShapeInference. + +#include "tensorflow/compiler/jit/shape_inference.h" + +#include +#include + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/test_util.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(ShapeInferenceTest, Basics) { + Scope root = Scope::NewRootScope().ExitOnError(); + auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT, + ops::Placeholder::Shape({2, 3})); + auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT, + ops::Placeholder::Shape({3})); + auto c = ops::Placeholder(root.WithOpName("C"), DT_FLOAT); + auto d = ops::Add(root.WithOpName("D"), a, b); + auto e = ops::Add(root.WithOpName("E"), d, c); + auto f = ops::Neg(root.WithOpName("F"), e); + auto g = ops::AddN(root.WithOpName("G"), std::initializer_list{e, f}); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(root.ToGraph(graph.get())); + + GraphShapeInfo shape_info; + TF_ASSERT_OK(InferShapes(graph.get(), /*arg_shapes=*/{}, + /*fnlib_def=*/nullptr, &shape_info)); + + std::map> expected = { + {"A", {PartialTensorShape({2, 3})}}, {"B", {PartialTensorShape({3})}}, + {"C", {PartialTensorShape()}}, {"D", {PartialTensorShape({2, 3})}}, + {"E", {PartialTensorShape()}}, {"F", 
{PartialTensorShape()}}, + {"G", {PartialTensorShape()}}, + }; + TF_EXPECT_OK(ShapeAnnotationsMatch(*graph, shape_info, expected)); +} + +TEST(ShapeInferenceTest, WhileLoop) { + // Graph: + // x = array_ops.placeholder(dtypes.int32) + // y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x]) + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32, + ops::Placeholder::Shape({})); + + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32, + ops::Placeholder::Shape({})); + auto enter = + ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop"); + // Add an unused Enter node. These should be ignored. + auto enter2 = + ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop"); + auto merge = ops::Merge(scope.WithOpName("while/Merge"), + std::initializer_list{enter, dummy}); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), + 10); + auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + auto switch_node = + ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); + auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"), + switch_node.output_false); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), + switch_node.output_true); + auto identity_shape = + ops::Const(scope.WithOpName("while/Identity/shape"), {}); + auto identity_reshaped = ops::Reshape( + scope.WithOpName("while/Identity/reshaped"), identity, identity_shape); + + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity_reshaped, one); + auto next_iteration = + ops::NextIteration(scope.WithOpName("while/NextIteration"), add); + + auto sink = 
ops::Identity(scope.WithOpName("sink"), exit); + + // Remove the dummy node and add the loop backedge. + scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + GraphShapeInfo shape_info; + TF_ASSERT_OK(InferShapes(&graph, /*arg_shapes=*/{}, /*fnlib_def=*/nullptr, + &shape_info)); + std::map> expected = { + {"while/Identity", {PartialTensorShape()}}, + {"while/add", {PartialTensorShape({})}}, + }; + TF_EXPECT_OK(ShapeAnnotationsMatch(graph, shape_info, expected)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/test_util.cc b/tensorflow/compiler/jit/test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..cada272090a1f613baea8f6d111866d8bb9cd55b --- /dev/null +++ b/tensorflow/compiler/jit/test_util.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/test_util.h" + +#include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace tensorflow { + +Status ShapeAnnotationsMatch( + const Graph& graph, const GraphShapeInfo& shape_info, + std::map> expected_shapes) { + for (Node* node : graph.op_nodes()) { + auto sit = shape_info.find(node->name()); + TF_RET_CHECK(sit != shape_info.end()) + << "Missing shape information for node " << node->name(); + std::vector shapes; + for (const auto& output : sit->second) shapes.push_back(output.shape); + + auto it = expected_shapes.find(node->name()); + if (it != expected_shapes.end()) { + if (!PartialTensorShapeUtils::AreIdentical(shapes, it->second)) { + return errors::InvalidArgument( + "Shape mismatch for ", node->name(), ". Expected: ", + PartialTensorShapeUtils::PartialShapeListString(it->second), + ", actual: ", + PartialTensorShapeUtils::PartialShapeListString(shapes)); + } + expected_shapes.erase(it); + } + } + if (!expected_shapes.empty()) { + std::vector missing; + missing.reserve(expected_shapes.size()); + for (const auto& entry : expected_shapes) { + missing.push_back(entry.first); + } + return errors::InvalidArgument("Missing shapes for nodes: ", + str_util::Join(missing, ",")); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/test_util.h b/tensorflow/compiler/jit/test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0c9fee8f2446d41f792a6cfbf8fc808d9d679c09 --- /dev/null +++ b/tensorflow/compiler/jit/test_util.h @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions for tests. + +#ifndef TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Tests that the shapes in 'shape_info' for the nodes in `graph` match +// `expected_shapes`. Returns an error if there are nodes in `expected_shapes` +// that do not have shape information. Ignores nodes in `graph` that do not have +// `expected_shapes` entries. 
+Status ShapeAnnotationsMatch( + const Graph& graph, const GraphShapeInfo& shape_info, + std::map> expected_shapes); + +} // namespace tensorflow + + +#endif // TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 3aa9e9c7ed2dd3b7480f40e868c6b07192b68294..826e98b96620165604594a22b81cd02422605c12 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -40,6 +40,7 @@ namespace tensorflow { XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client, DeviceType device_type) : client_(client), device_type_(std::move(device_type)) {} + XlaCompilationCache::~XlaCompilationCache() { // Ensure any use of our programs have completed by waiting for all stream // executors to complete. @@ -228,37 +229,45 @@ Status XlaCompilationCache::Compile( const XlaCompiler::Options& options, const NameAttrList& function, const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options) { + const XlaCompiler::CompileOptions& compile_options, + CompileMode compile_mode, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable) { + // Set the compile threshold to 1 to implement CompileMode::kStrict. + int64 compile_threshold = + compile_mode == CompileMode::kLazy ? 
kDefaultCompilationThreshold : 1; return CompileImpl(options, function, constant_args, variable_args, ctx, - compilation_result, executable, compile_options, false); + compile_options, /*compile_single_op=*/false, + /*compile_threshold=*/compile_threshold, + out_compilation_result, out_executable); } Status XlaCompilationCache::CompileSingleOp( const XlaCompiler::Options& options, const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options) { + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable) { const NodeDef& def = ctx->op_kernel().def(); NameAttrList name; name.set_name(def.op()); *name.mutable_attr() = def.attr(); return CompileImpl(options, name, constant_args, variable_args, ctx, - compilation_result, executable, compile_options, true); + compile_options, + /*compile_single_op=*/true, /*compile_threshold=*/1, + out_compilation_result, out_executable); } Status XlaCompilationCache::CompileImpl( const XlaCompiler::Options& options, const NameAttrList& function, const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options, - bool compile_single_op) { - CHECK_NE(executable, nullptr); + const XlaCompiler::CompileOptions& compile_options, bool compile_single_op, + int64 compile_threshold, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable) { + DCHECK_NE(out_executable, nullptr); VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { @@ -309,9 +318,18 @@ Status XlaCompilationCache::CompileImpl( // TODO(phawkins): this locking will need to 
be restructured when we implement // cache eviction. mutex_lock entry_lock(entry->mu); + int64 current_request_count = ++entry->request_count; if (!entry->compiled) { VLOG(2) << "Compilation cache miss for signature: " - << SignatureDebugString(signature); + << SignatureDebugString(signature) << " with request count " + << current_request_count << " and compile threshold " + << compile_threshold; + if (current_request_count < compile_threshold) { + *out_compilation_result = nullptr; + *out_executable = nullptr; + return Status::OK(); + } + tensorflow::Env* env = tensorflow::Env::Default(); const uint64 compile_start_us = env->NowMicros(); // Do the actual JIT compilation without holding the lock (it can take @@ -357,8 +375,8 @@ Status XlaCompilationCache::CompileImpl( } } TF_RETURN_IF_ERROR(entry->compilation_status); - *compilation_result = &entry->compilation_result; - *executable = entry->executable.get(); + *out_compilation_result = &entry->compilation_result; + *out_executable = entry->executable.get(); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 10ad87e38cc4d614e869782329f84351bc3b1f0b..f06a991818db53adb3e5c0cc483c6180128a87e7 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -24,7 +25,6 @@ limitations under the License. 
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -50,6 +50,11 @@ class XlaCompilationCache : public ResourceBase { XlaCompilationCache(xla::LocalClient* client, DeviceType device_type); ~XlaCompilationCache() override; + enum class CompileMode { + kLazy, + kStrict, + }; + // Compiles a function into a XlaCompiler::CompilationResult that can be used // to execute an XLA Computation. Compilation results are cached. // `function` is the name of a Tensorflow function to compile. @@ -58,6 +63,14 @@ class XlaCompilationCache : public ResourceBase { // `variable_args` is a snapshot of the current values of the // resource variable arguments to `function`; uninitialized variables are // represented by an absent OptionalTensor. + // + // `compile_mode` controls the behavior of the compilation cache on a cache + // miss. If `compile_mode` is `kLazy` then, based on some profitability + // heuristics, the compilation cache may decide not to compile the cluster at + // this time. In this case it returns null into both `out_compilation_result` + // and `out_executable`. If `compile_mode` is `kStrict` then the compilation + // cache always attempts the compilation on a cache miss. + // // The result of compilation is written to `*compilation_result`, which must // be non-null. If `executable` is non-null, also builds an // xla::LocalExecutable and sets `executable` to point to it. 
The resulting @@ -68,9 +81,10 @@ class XlaCompilationCache : public ResourceBase { const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options); + const XlaCompiler::CompileOptions& compile_options, + CompileMode compile_mode, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); // As above, but calls XlaCompiler::CompileSingleOp instead of // XlaCompiler::CompileFunction. @@ -78,9 +92,9 @@ class XlaCompilationCache : public ResourceBase { const XlaCompiler::Options& options, const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options); + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } @@ -89,15 +103,14 @@ class XlaCompilationCache : public ResourceBase { private: // Common implementation of Compile and CompileSingleOp. 
- Status CompileImpl(const XlaCompiler::Options& options, - const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options, - bool compile_single_op); + Status CompileImpl( + const XlaCompiler::Options& options, const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompileOptions& compile_options, + bool compile_single_op, int64 compile_threshold, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); // Takes `result` which has been compiled from a Tensorflow subgraph to a // XLA computation already, and generates an XLA LocalExecutable `executable`. @@ -140,6 +153,9 @@ class XlaCompilationCache : public ResourceBase { // Have we tried compiling this entry? bool compiled = false; + // The number of times a compilation with this signature has been requested. + int64 request_count = 0; + // Did compilation succeed? Status compilation_status GUARDED_BY(mu); @@ -152,7 +168,7 @@ class XlaCompilationCache : public ResourceBase { }; mutex compile_cache_mu_; - gtl::FlatMap, Signature::Hash> cache_ + absl::flat_hash_map, Signature::Hash> cache_ GUARDED_BY(compile_cache_mu_); struct CompileStats { @@ -165,9 +181,13 @@ class XlaCompilationCache : public ResourceBase { mutex compile_stats_mu_; // Maps cluster names to compilation statistics for said cluster. - gtl::FlatMap compile_stats_ + absl::flat_hash_map compile_stats_ GUARDED_BY(compile_stats_mu_); + // The number of times a lazy compilation must be requested for a specific + // signature before we attempt to compile it. 
+ static constexpr int64 kDefaultCompilationThreshold = 2; + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); }; diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index b98c0cb028ff069278dceda21f4588c0da9086e5..129528bb4428564a130f1eaa30f5d887ce0349dc 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -16,6 +16,8 @@ limitations under the License. // Defines the XlaCompileOnDemandOp. #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" + +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -164,8 +166,9 @@ Status XlaCompileOnDemandOp::Compile( XlaCompiler::Options options; options.device_type = metadata.jit_device_type(); options.client = metadata.client(); - options.flib_def = - new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); + auto flib_def = absl::make_unique( + OpRegistry::Global(), FunctionDefLibrary{}); + options.flib_def = flib_def.get(); options.shape_representation_fn = metadata.shape_representation_fn(); XlaCompiler::CompileOptions compile_options; @@ -180,7 +183,7 @@ Status XlaCompileOnDemandOp::Compile( std::map variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - result, executable, compile_options); + compile_options, result, executable); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index af83c792e5e11d8596c521c6a3aed332a1f42e5b..090021093d67384521f5fad43b226b5263829c99 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -50,7 +50,7 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void 
XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } -XlaTransferManager::XlaTransferManager( +XlaDeviceContext::XlaDeviceContext( std::shared_ptr compute_stream, std::shared_ptr host_to_device_stream, std::shared_ptr device_to_host_stream, xla::LocalClient* client, @@ -75,8 +75,8 @@ XlaTransferManager::XlaTransferManager( } } -Status XlaTransferManager::TransferLiteralToDevice( - const Tensor& host_tensor, Tensor* device_tensor) const { +Status XlaDeviceContext::TransferLiteralToDevice(const Tensor& host_tensor, + Tensor* device_tensor) const { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), host_tensor.shape(), &xla_shape)); @@ -112,7 +112,7 @@ Status XlaTransferManager::TransferLiteralToDevice( return Status::OK(); } -void XlaTransferManager::TransferLiteralFromDevice( +void XlaDeviceContext::TransferLiteralFromDevice( Tensor* host_tensor, const Tensor& device_tensor, const StatusCallback& done) const { xla::MutableBorrowingLiteral literal; @@ -134,10 +134,10 @@ void XlaTransferManager::TransferLiteralFromDevice( }); } -void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, - Device* device, - Tensor* device_tensor, - StatusCallback done) const { +void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, + Device* device, + Tensor* device_tensor, + StatusCallback done) const { if (cpu_tensor->NumElements() == 0) { VLOG(2) << "CopyCPUTensorToDevice empty tensor"; done(Status::OK()); @@ -202,11 +202,10 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, done(status); } -void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, - absl::string_view tensor_name, - Device* device, - Tensor* cpu_tensor, - StatusCallback done) { +void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, + absl::string_view tensor_name, + Device* device, Tensor* cpu_tensor, + StatusCallback done) { if (device_tensor->NumElements() == 0) { VLOG(2) 
<< "CopyDeviceTensorToCPU empty tensor"; done(Status::OK()); @@ -250,9 +249,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(status); } -void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, - Tensor* dst_tensor, - const StatusCallback& done) { +void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor, + Tensor* dst_tensor, + const StatusCallback& done) { VLOG(2) << "CopyDeviceTensorToDevice " << reinterpret_cast(src_tensor.tensor_data().data()) << " " @@ -320,36 +319,4 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, } } -XlaDeviceContext::XlaDeviceContext( - std::shared_ptr compute_stream, - std::shared_ptr host_to_device_stream, - std::shared_ptr device_to_host_stream, xla::LocalClient* client, - bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - thread::ThreadPool* thread_pool) - : manager_(std::move(compute_stream), std::move(host_to_device_stream), - std::move(device_to_host_stream), client, transfer_as_literal, - std::move(shape_representation_fn), thread_pool) {} - -void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, - Device* device, - Tensor* device_tensor, - StatusCallback done) const { - manager_.CopyCPUTensorToDevice(cpu_tensor, device, device_tensor, done); -} - -void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - absl::string_view tensor_name, - Device* device, Tensor* cpu_tensor, - StatusCallback done) { - manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor, - done); -} - -void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor, - Tensor* dst_tensor, - const StatusCallback& done) { - manager_.CopyDeviceTensorToDevice(src_tensor, dst_tensor, done); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 
df824212948ac96a5df5228cecd9a8c864bbec9a..babb60acb5ca547d47825022003b296b1e5d0324 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -44,9 +44,9 @@ class XlaDeviceAllocator : public Allocator { }; // Helper class for managing data transfers between host and XLA devices. -class XlaTransferManager { +class XlaDeviceContext : public DeviceContext { public: - explicit XlaTransferManager( + explicit XlaDeviceContext( std::shared_ptr compute_stream, std::shared_ptr host_to_device_stream, std::shared_ptr device_to_host_stream, @@ -55,10 +55,11 @@ class XlaTransferManager { thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, - Tensor* device_tensor, StatusCallback done) const; + Tensor* device_tensor, + StatusCallback done) const override; void CopyDeviceTensorToCPU(const Tensor* device_tensor, absl::string_view tensor_name, Device* device, - Tensor* cpu_tensor, StatusCallback done); + Tensor* cpu_tensor, StatusCallback done) override; void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); @@ -94,34 +95,6 @@ class XlaTransferManager { thread::ThreadPool* thread_pool_; }; -// DeviceContext for operators assigned to XlaDevice devices. The -// implementation must inherit from DeviceContext but otherwise just -// wraps the methods in XlaTransferManager. 
-class XlaDeviceContext : public DeviceContext { - public: - explicit XlaDeviceContext( - std::shared_ptr compute_stream, - std::shared_ptr host_to_device_stream, - std::shared_ptr device_to_host_stream, - xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - thread::ThreadPool* thread_pool); - - void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, - Tensor* device_tensor, - StatusCallback done) const override; - void CopyDeviceTensorToCPU(const Tensor* device_tensor, - absl::string_view tensor_name, Device* device, - Tensor* cpu_tensor, StatusCallback done) override; - void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, - const StatusCallback& done); - - se::Stream* stream() const override { return manager_.stream(); } - - private: - XlaTransferManager manager_; -}; - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_ diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 6967ad1f03fb5dd962d5b41f0c7ab1dfa42fab94..6a1c43aa96026a991c8a8d016d67b5ca048c293c 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -65,11 +65,13 @@ class XlaAssignVariableOp : public AsyncOpKernel { .HostMemory("resources"), \ KERNEL); -#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \ - REGISTER_KERNEL_BUILDER(Name("_XlaCompile") \ - .Device(DEVICE) \ - .HostMemory("constants") \ - .HostMemory("resources"), \ +#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("_XlaCompile") \ + .Device(DEVICE) \ + .HostMemory("constants") \ + .HostMemory("key") \ + .HostMemory("compilation_successful") \ + .HostMemory("resources"), \ KERNEL); #define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \ @@ -208,6 +210,8 @@ class XlaAssignVariableOp : public AsyncOpKernel { .TypeConstraint("T") \ .HostMemory("input"), \ RetvalOp); \ 
+ REGISTER_KERNEL_BUILDER( \ + Name(kDeviceRetOp).Device(DEVICE).TypeConstraint("T"), RetvalOp); \ \ REGISTER_KERNEL_BUILDER( \ Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp); \ diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 19e681af0c940023de2ce82b3b337babe2f3dd5a..8a80639b6391ba9b73fe3143df8f6e44505cec2c 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -25,8 +25,9 @@ namespace tensorflow { const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER"; const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT"; -constexpr std::array kExecAllTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kExecAllTypes = { + {DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_BOOL, DT_BFLOAT16}}; class XlaInterpreterDeviceFactory : public DeviceFactory { public: diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 4f6fc4e068e3ba125ddbca264c1affa1f09f5896..0e8ee56ed8979111b66e3886f07994c8b665c388 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -239,7 +239,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( // Copy host -> device. (Empty tensors don't have backing buffers.) // Manually allocate memory using an XlaTensorBuffer so we can allocate // as much memory as the device requires (as given by - // GetByteSizeRequirement). This avoids XlaTransferManager having to + // GetByteSizeRequirement). This avoids XlaDeviceContext having to // reallocate the device buffer later. 
VLOG(1) << "Constant output tensor on device"; diff --git a/tensorflow/compiler/plugin/README.md b/tensorflow/compiler/plugin/README.md index 9dd0d2bdab5e2c990fd547cef4b657253c545715..07465934aec0364eb03ddfb7f99ea54aaf084fff 100644 --- a/tensorflow/compiler/plugin/README.md +++ b/tensorflow/compiler/plugin/README.md @@ -1,5 +1,4 @@ -3rd party XLA devices ---------------------- +## 3rd party XLA devices This directory is intended as a place for 3rd party XLA devices which are _not_ integrated into the public repository. @@ -9,8 +8,5 @@ can be included as a dependency of the JIT subsystem. For integration into the unit test system, see the files: -- tensorflow/compiler/tests/plugin.bzl -- tensorflow/compiler/xla/tests/plugin.bzl - - -- +- tensorflow/compiler/tests/plugin.bzl +- tensorflow/compiler/xla/tests/plugin.bzl diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 3cf74fa7880c96198f9072ab7488a1cec15c9e5c..d6e3f0817edbc21a1dfdccfd9d075c12f7010d97 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -867,9 +867,9 @@ tf_xla_py_test( tags = ["optonly"], deps = [ ":xla_test", - "//tensorflow/contrib/stateless", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python:stateless_random_ops", ], ) @@ -894,6 +894,22 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "tensor_list_ops_test", + size = "small", + srcs = ["tensor_list_ops_test.py"], + # TensorList ops are not implemented in the on-demand compilation model yet. 
+ disabled_backends = "cpu_ondemand", + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:list_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python/eager:function", + ], +) + tf_xla_py_test( name = "ternary_ops_test", size = "small", @@ -1028,6 +1044,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "permute_test", + size = "small", + srcs = ["permute_test.py"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:nn_ops", + ], +) + tf_xla_py_test( name = "xla_device_test", size = "small", @@ -1060,6 +1089,7 @@ cuda_py_test( size = "medium", srcs = ["jit_test.py"], additional_deps = [ + ":test_utils", "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -1078,6 +1108,7 @@ cuda_py_test( size = "small", srcs = ["dense_layer_test.py"], additional_deps = [ + ":test_utils", "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -1105,6 +1136,7 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index 9390870e07d6b5bd90dbc5c04bac0946595dcf7f..d1b90f098d7d6574999ba0af44b285f5ad5e4f8d 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os import numpy as np +from tensorflow.compiler.tests import test_utils from tensorflow.contrib.compiler import jit from tensorflow.core.protobuf import config_pb2 from tensorflow.python.layers import layers @@ -30,7 +31,6 @@ from 
tensorflow.python.platform import test jit_scope = jit.experimental_jit_scope - def GetRunMetadataLabels(run_metadata): """Returns all labels in run_metadata.""" labels = [] @@ -68,13 +68,14 @@ class DenseLayerTest(test.TestCase): config.graph_options.optimizer_options.global_jit_level = ( config_pb2.OptimizerOptions.ON_1) - with self.test_session(config=config) as sess: + with self.session(config=config) as sess: x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) y = layers.dense(x, 3) sess.run(variables.initialize_all_variables()) run_metadata = config_pb2.RunMetadata() - sess.run( + test_utils.RunWithWarmup( + sess, y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, run_metadata=run_metadata, options=config_pb2.RunOptions( @@ -98,7 +99,8 @@ class DenseLayerTest(test.TestCase): sess.run(variables.initialize_all_variables()) run_metadata = config_pb2.RunMetadata() - sess.run( + test_utils.RunWithWarmup( + sess, y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, run_metadata=run_metadata, options=config_pb2.RunOptions( @@ -126,7 +128,8 @@ class DenseLayerTest(test.TestCase): sess.run(variables.initialize_all_variables()) run_metadata = config_pb2.RunMetadata() - sess.run( + test_utils.RunWithWarmup( + sess, y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, run_metadata=run_metadata, options=config_pb2.RunOptions( @@ -138,4 +141,6 @@ class DenseLayerTest(test.TestCase): if __name__ == "__main__": + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " + + os.environ.get("TF_XLA_FLAGS", "")) test.main() diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 68fdb5caf4c2a496b5058cdda40ca650484a6e0e..d67b16f8e9e7320d5717b0203be340a2356e53d0 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -26,7 +26,6 @@ import numpy as np from six.moves import xrange # pylint: 
disable=redefined-builtin from tensorflow.compiler.tests import xla_test -from tensorflow.python.compat import compat from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -605,168 +604,205 @@ class ResizeBilinearTest(xla_test.XLATestCase): class NonMaxSuppressionTest(xla_test.XLATestCase): def testNMS128From1024(self): - with compat.forward_compatibility_horizon(2018, 8, 8): - num_boxes = 1024 - boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") - scores_np = np.random.normal(0.5, 0.1, (num_boxes,)).astype("f4") - - max_output_size = 128 - iou_threshold_np = np.array(0.5, dtype=np.float32) - score_threshold_np = np.array(0.0, dtype=np.float32) - - with self.cached_session() as sess: - boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) - scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) - iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, - iou_threshold_np.shape) - score_threshold = array_ops.placeholder(score_threshold_np.dtype, - score_threshold_np.shape) - with self.test_scope(): - selected_indices = image_ops.non_max_suppression_padded( - boxes=boxes, - scores=scores, - max_output_size=max_output_size, - iou_threshold=iou_threshold, - score_threshold=score_threshold, - pad_to_max_output_size=True) - inputs_feed = { - boxes: boxes_np, - scores: scores_np, - score_threshold: score_threshold_np, - iou_threshold: iou_threshold_np - } - (indices_tf, _) = sess.run(selected_indices, feed_dict=inputs_feed) - - self.assertEqual(indices_tf.size, max_output_size) + num_boxes = 1024 + boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") + scores_np = np.random.normal(0.5, 0.1, (num_boxes,)).astype("f4") + + max_output_size = 128 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.cached_session() as sess: + boxes = 
array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, _) = sess.run(selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) def testNMS3From6Boxes(self): - with compat.forward_compatibility_horizon(2018, 8, 8): - # Three boxes are selected based on IOU. - boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], - [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] - boxes_np = np.array(boxes_data, dtype=np.float32) - - scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] - scores_np = np.array(scores_data, dtype=np.float32) - - max_output_size = 3 - iou_threshold_np = np.array(0.5, dtype=np.float32) - score_threshold_np = np.array(0.0, dtype=np.float32) - - with self.cached_session() as sess: - boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) - scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) - iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, - iou_threshold_np.shape) - score_threshold = array_ops.placeholder(score_threshold_np.dtype, - score_threshold_np.shape) - with self.test_scope(): - selected_indices = image_ops.non_max_suppression_padded( - boxes=boxes, - scores=scores, - max_output_size=max_output_size, - iou_threshold=iou_threshold, - score_threshold=score_threshold, - pad_to_max_output_size=True) - 
inputs_feed = { - boxes: boxes_np, - scores: scores_np, - score_threshold: score_threshold_np, - iou_threshold: iou_threshold_np - } - (indices_tf, num_valid) = sess.run( - selected_indices, feed_dict=inputs_feed) - - self.assertEqual(indices_tf.size, max_output_size) - self.assertEqual(num_valid, 3) - self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) + # Three boxes are selected based on IOU. + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 3) + self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) def testNMS3Then2WithScoreThresh(self): # Three boxes are selected based on IOU. # One is filtered out by score threshold. 
- with compat.forward_compatibility_horizon(2018, 8, 8): - boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], - [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] - boxes_np = np.array(boxes_data, dtype=np.float32) - - scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] - scores_np = np.array(scores_data, dtype=np.float32) - max_output_size = 3 - iou_threshold_np = np.array(0.5, dtype=np.float32) - score_threshold_np = np.array(0.4, dtype=np.float32) - - with self.cached_session() as sess: - boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) - scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) - iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, - iou_threshold_np.shape) - score_threshold = array_ops.placeholder(score_threshold_np.dtype, - score_threshold_np.shape) - with self.test_scope(): - selected_indices = image_ops.non_max_suppression_padded( - boxes=boxes, - scores=scores, - max_output_size=max_output_size, - iou_threshold=iou_threshold, - score_threshold=score_threshold, - pad_to_max_output_size=True) - inputs_feed = { - boxes: boxes_np, - scores: scores_np, - iou_threshold: iou_threshold_np, - score_threshold: score_threshold_np - } - (indices_tf, num_valid) = sess.run( - selected_indices, feed_dict=inputs_feed) - - self.assertEqual(indices_tf.size, max_output_size) - self.assertEqual(num_valid, 2) - self.assertAllClose(indices_tf[:num_valid], [3, 0]) + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = 
array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 2) + self.assertAllClose(indices_tf[:num_valid], [3, 0]) def testNMS3Then1WithScoreMaxThresh(self): # Three boxes are selected based on IOU. # One is filtered out by score threshold. # One is filtered out by max_output_size. 
- with compat.forward_compatibility_horizon(2018, 8, 8): - boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], - [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] - boxes_np = np.array(boxes_data, dtype=np.float32) - - scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] - scores_np = np.array(scores_data, dtype=np.float32) - max_output_size = 1 - iou_threshold_np = np.array(0.5, dtype=np.float32) - score_threshold_np = np.array(0.4, dtype=np.float32) - - with self.cached_session() as sess: - boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) - scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) - iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, - iou_threshold_np.shape) - score_threshold = array_ops.placeholder(score_threshold_np.dtype, - score_threshold_np.shape) - with self.test_scope(): - selected_indices = image_ops.non_max_suppression_padded( - boxes=boxes, - scores=scores, - max_output_size=max_output_size, - iou_threshold=iou_threshold, - score_threshold=score_threshold, - pad_to_max_output_size=True) - inputs_feed = { - boxes: boxes_np, - scores: scores_np, - iou_threshold: iou_threshold_np, - score_threshold: score_threshold_np - } - (indices_tf, num_valid) = sess.run( - selected_indices, feed_dict=inputs_feed) - - self.assertEqual(indices_tf.size, max_output_size) - self.assertEqual(num_valid, 1) - self.assertAllClose(indices_tf[:num_valid], [3]) + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 1 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = 
array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 1) + self.assertAllClose(indices_tf[:num_valid], [3]) + + def testSelectFromContinuousOverLap(self): + # Tests that a suppressed box does not itself suppress other boxes. + + boxes_data = [[0, 0, 1, 1], [0, 0.2, 1, 1.2], [0, 0.4, 1, 1.4], + [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 3]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.5, 0.4, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.1, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + 
inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 3) + self.assertAllClose(indices_tf[:num_valid], [0, 2, 4]) if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index de68ff0e32cd59e65094c0b7319f8ab213eed4db..8778b54dfaf35003c83cf2ab03e9e218c60c98ed 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os import numpy as np +from tensorflow.compiler.tests import test_utils from tensorflow.contrib.compiler import jit from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 @@ -36,8 +37,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test -jit_scope = jit.experimental_jit_scope +jit_scope = jit.experimental_jit_scope # Disable rewrites to make sure we don't end up having to update this test # whenever we implement new ones. @@ -77,11 +78,11 @@ def InLabels(labels, substr): return any([substr in x for x in labels]) -def MetadataHasXlaOp(run_metadata): +def MetadataHasXlaRunOp(run_metadata): """Returns true if there are XlaRun kernels in run_metadata's timeline.""" # TODO(phawkins): find a less hacky way to test whether a kernel ran. 
- return InLabels(RunMetadataLabels(run_metadata), "XlaRun") + return InLabels(RunMetadataLabels(run_metadata), "_XlaRun") class JitLaunchTest(test.TestCase): @@ -108,15 +109,14 @@ class JitLaunchTest(test.TestCase): direct_op = fn(*placeholders) run_metadata = config_pb2.RunMetadata() - compiled = sess.run(compiled_op, - feeds, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) + compiled = test_utils.RunWithWarmup( + sess, compiled_op, feeds, + config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE), + run_metadata) print("Compiled Result {}".format(compiled)) if require_kernel_launch: - self.assert_(MetadataHasXlaOp(run_metadata)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) direct = sess.run(direct_op, feeds) print("Direct Result {}".format(direct)) @@ -137,7 +137,7 @@ class JitLaunchTest(test.TestCase): a = constant_op.constant(100) # pylint: disable=unused-variable call = KernelWithNoOutputs() # pylint: disable=assignment-from-no-return - sess.run(call, {}) + test_utils.RunWithWarmup(sess, call, {}) def testAliasing(self): """Regression test for compiled functions that return an aliased buffer. @@ -250,17 +250,21 @@ class JitLaunchTest(test.TestCase): dx = np.random.random_sample((batch_size, image_size)).astype(np.float32) with session_lib.Session() as sess: run_metadata = config_pb2.RunMetadata() - output = sess.run(y, {x: dx, - w: dw, - b: db}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) + output = test_utils.RunWithWarmup( + sess, + y, { + x: dx, + w: dw, + b: db + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) # TODO(phawkins): really we would like to test that there were exactly # two kernel launches. However, we have no reliable way to determine # that. 
- self.assert_(MetadataHasXlaOp(run_metadata)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) expected = np.square(np.dot(dx, dw) + db) self.assertAllClose(expected, output, rtol=1e-1) @@ -272,7 +276,7 @@ class XlaCompilationTest(test.TestCase): def testReshape(self): """Tests an operator with compile-time constant and non-constant inputs.""" - with self.test_session(config=NoRewriteSessionConfig()) as sess: + with self.session(config=NoRewriteSessionConfig()) as sess: x = array_ops.placeholder(dtypes.float32) y = array_ops.placeholder(dtypes.int32) with jit_scope(): @@ -284,19 +288,22 @@ class XlaCompilationTest(test.TestCase): # statically known as part of the JIT compilation's input graph. z = array_ops.reshape(x, y) run_metadata = config_pb2.RunMetadata() - out = sess.run(z, - {x: np.array([1, 2, 3, 4, 5, 6], np.float32), - y: [-1, 3]}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaOp(run_metadata)) + out = test_utils.RunWithWarmup( + sess, + z, { + x: np.array([1, 2, 3, 4, 5, 6], np.float32), + y: [-1, 3] + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out) def testIgnoredArguments(self): """Tests that JIT computations can ignore formal parameters.""" - with self.test_session(config=NoRewriteSessionConfig()) as sess: + with self.session(config=NoRewriteSessionConfig()) as sess: x = array_ops.placeholder(dtypes.int32) y = array_ops.placeholder(dtypes.int32) with jit_scope(): @@ -309,18 +316,22 @@ class XlaCompilationTest(test.TestCase): t = math_ops.add(z, z) run_metadata = config_pb2.RunMetadata() - out = sess.run(t, {x: np.int32(7), - y: np.int32(404)}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) - 
self.assert_(MetadataHasXlaOp(run_metadata)) + out = test_utils.RunWithWarmup( + sess, + t, { + x: np.int32(7), + y: np.int32(404) + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) self.assertAllClose(28, out) def testLoops(self): """Tests that compilation accepts computations containing loops.""" - with self.test_session(config=NoRewriteSessionConfig()) as session: + with self.session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) with jit_scope(): c = lambda i, _: math_ops.less(i, 5) @@ -332,13 +343,13 @@ class XlaCompilationTest(test.TestCase): run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaOp(run_metadata)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) self.assertAllClose(result, np.float32(95), rtol=1e-1) def testCond(self): """Tests that compilation handles switch operators.""" - with self.test_session(config=NoRewriteSessionConfig()) as session: + with self.session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) y = array_ops.placeholder(dtypes.float32) c = array_ops.placeholder(dtypes.bool) @@ -351,13 +362,17 @@ class XlaCompilationTest(test.TestCase): # deadlock. 
run_metadata = config_pb2.RunMetadata() - result = session.run(t, {x: np.float32(2), - y: np.float32(4), - c: True}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaOp(run_metadata)) + result = test_utils.RunWithWarmup( + session, + t, { + x: np.float32(2), + y: np.float32(4), + c: True + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) self.assertAllClose(result, np.float32(6), rtol=1e-1) def testNestedFunction(self): @@ -379,7 +394,7 @@ class XlaCompilationTest(test.TestCase): inp = array_ops.placeholder(dtypes.float32) out = Entry(inp) - with self.test_session( + with self.session( config=NoRewriteSessionConfig(), graph=g, use_gpu=True) as sess: run_metadata = config_pb2.RunMetadata() val = sess.run(out, @@ -392,7 +407,7 @@ class XlaCompilationTest(test.TestCase): def testLoopDeadlock(self): """Regression test for bug that caused deadlocks in graphs with loops.""" - with self.test_session(config=NoRewriteSessionConfig()) as session: + with self.session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) with jit_scope(): y = x + 1.0 @@ -425,11 +440,13 @@ class XlaCompilationTest(test.TestCase): cfg.graph_options.optimizer_options.do_function_inlining = True with session_lib.Session(graph=g, config=cfg) as sess: run_metadata = config_pb2.RunMetadata() - dx_val = sess.run(dx, - feed_dict={x: 100.}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) + dx_val = test_utils.RunWithWarmup( + sess, + dx, + feed_dict={x: 100.}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) self.assertAllClose(dx_val, 0.01) return RunMetadataLabels(run_metadata) @@ -475,7 +492,8 @@ class ElementWiseFusionTest(test.TestCase): a7 
= a6 + a2 run_metadata = config_pb2.RunMetadata() - output = sess.run( + output = test_utils.RunWithWarmup( + sess, a7, { a1: arg0, a2: arg1 @@ -509,5 +527,60 @@ class ElementWiseFusionTest(test.TestCase): self.assertAllClose(tf_op, tfef_op, rtol=1e-1) +class LazyCompilationTest(test.TestCase): + + def testLazyCompilation(self): + + @function.Defun(compiled=True) + def CompiledFunction(x): + return math_ops.log(x) + + with session_lib.Session(config=NoRewriteSessionConfig()) as sess: + x = array_ops.placeholder(dtypes.float32) + y = CompiledFunction(x) + + run_metadata_before_warmup = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [2., 10.]}, + run_metadata=run_metadata_before_warmup, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue( + InLabels( + RunMetadataLabels(run_metadata_before_warmup), "_XlaCompile")) + self.assertFalse( + InLabels(RunMetadataLabels(run_metadata_before_warmup), "_XlaRun")) + + # We compile when we see the same shape a second time. 
+ + run_metadata_after_warmup = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [2., 10.]}, + run_metadata=run_metadata_after_warmup, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue( + InLabels(RunMetadataLabels(run_metadata_after_warmup), "_XlaCompile")) + self.assertTrue( + InLabels(RunMetadataLabels(run_metadata_after_warmup), "_XlaRun")) + + run_metadata_for_new_shape = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [2., 10., 12.]}, + run_metadata=run_metadata_for_new_shape, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue( + InLabels( + RunMetadataLabels(run_metadata_for_new_shape), "_XlaCompile")) + self.assertFalse( + InLabels(RunMetadataLabels(run_metadata_for_new_shape), "_XlaRun")) + + if __name__ == "__main__": + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " + + os.environ.get("TF_XLA_FLAGS", "")) test.main() diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py index f985c5d2d96e06fc0117f3935d61b19c9e8562b1..38cb2f83efc48ffcdf5403a23e666963b2ea4da1 100644 --- a/tensorflow/compiler/tests/nullary_ops_test.py +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -43,18 +43,37 @@ class NullaryOpsTest(xla_test.XLATestCase): output.run() def testConstants(self): - constants = [ - np.float32(42), - np.array([], dtype=np.float32), - np.array([1, 2], dtype=np.float32), - np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32), - np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], - dtype=np.float32), - np.array([[[]], [[]]], dtype=np.float32), - np.array([[[[1]]]], dtype=np.float32), - ] - for c in constants: - self._testNullary(lambda c=c: constant_op.constant(c), expected=c) + for dtype in self.numeric_types: + constants = [ + dtype(42), + np.array([], dtype=dtype), + np.array([1, 2], dtype=dtype), + np.array([7, 7, 7, 7, 7], dtype=dtype), + 
np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype), + np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], + dtype=dtype), + np.array([[[]], [[]]], dtype=dtype), + np.array([[[[1]]]], dtype=dtype), + ] + for c in constants: + self._testNullary(lambda c=c: constant_op.constant(c), expected=c) + + def testComplexConstants(self): + for dtype in self.complex_types: + constants = [ + dtype(42 + 3j), + np.array([], dtype=dtype), + np.ones([50], dtype=dtype) * (3 + 4j), + np.array([1j, 2 + 1j], dtype=dtype), + np.array([[1, 2j, 7j], [4, 5, 6]], dtype=dtype), + np.array([[[1, 2], [3, 4 + 6j], [5, 6]], + [[10 + 7j, 20], [30, 40], [50, 60]]], + dtype=dtype), + np.array([[[]], [[]]], dtype=dtype), + np.array([[[[1 + 3j]]]], dtype=dtype), + ] + for c in constants: + self._testNullary(lambda c=c: constant_op.constant(c), expected=c) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/permute_test.py b/tensorflow/compiler/tests/permute_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f6de821b5fd4709d305bcd17ee6ba40b1443fd --- /dev/null +++ b/tensorflow/compiler/tests/permute_test.py @@ -0,0 +1,88 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the DataFormatVecPermute operator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class XlaPermuteOpTest(xla_test.XLATestCase): + + def _runPermuteAndCompare(self, x, src_format, dst_format, expected): + with self.cached_session() as session: + with self.test_scope(): + placeholder = array_ops.placeholder(dtypes.as_dtype(x.dtype), x.shape) + param = {placeholder: x} + output = nn_ops.data_format_vec_permute( + placeholder, src_format=src_format, dst_format=dst_format) + result = session.run(output, param) + self.assertAllEqual(result, expected) + + def testNHWCToNCHW(self): + for dtype in {np.int32, np.int64}: + x = np.array([7, 4, 9, 3], dtype=dtype) + self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9]) + + def testNCHWToNHWC(self): + for dtype in {np.int32, np.int64}: + x = np.array([7, 4, 9, 3], dtype=dtype) + self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4]) + + def testNHWCToHWNC(self): + for dtype in {np.int32, np.int64}: + x = np.array([7, 4, 9, 3], dtype=dtype) + self._runPermuteAndCompare(x, "NHWC", "HWNC", [4, 9, 7, 3]) + + def testHWNCToNHWC(self): + for dtype in {np.int32, np.int64}: + x = np.array([7, 4, 9, 3], dtype=dtype) + self._runPermuteAndCompare(x, "HWNC", "NHWC", [9, 7, 4, 3]) + + def testNHWCToNCHW2D(self): + for dtype in {np.int32, np.int64}: + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=dtype) + self._runPermuteAndCompare(x, "NHWC", "NCHW", + [[7, 4], [5, 1], [9, 3], [4, 5]]) + + def testNHWCToHWNC2D(self): + for dtype in {np.int32, np.int64}: + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], 
dtype=dtype) + self._runPermuteAndCompare(x, "NHWC", "HWNC", + [[9, 3], [4, 5], [7, 4], [5, 1]]) + + def testHWNCToNHWC2D(self): + for dtype in {np.int32, np.int64}: + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=dtype) + self._runPermuteAndCompare(x, "HWNC", "NHWC", + [[4, 5], [7, 4], [9, 3], [5, 1]]) + + def testNCHWToNHWC2D(self): + for dtype in {np.int32, np.int64}: + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=dtype) + self._runPermuteAndCompare(x, "NCHW", "NHWC", + [[7, 4], [4, 5], [5, 1], [9, 3]]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index bddda6f30245d4b8281a77783ec9922d61bd3883..dc119fb0f8a41a3772a8c9508bf2db657f57de88 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -45,6 +45,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/defs.h" @@ -63,7 +64,6 @@ limitations under the License. 
#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" @@ -457,7 +457,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { - gtl::FlatSet already_generated; + absl::flat_hash_set already_generated; std::uniform_real_distribution distribution(-1.0f, 1.0f); test::FillFn(&tensor, [&](int i) -> float { float generated; @@ -470,7 +470,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_DOUBLE: { - gtl::FlatSet already_generated; + absl::flat_hash_set already_generated; std::uniform_real_distribution distribution(-1.0, 1.0); test::FillFn(&tensor, [&](int i) -> double { double generated; @@ -483,7 +483,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_COMPLEX64: { - gtl::FlatSet> already_generated; + absl::flat_hash_set> already_generated; std::uniform_real_distribution distribution(-1.0f, 1.0f); test::FillFn(&tensor, [&](int i) { complex64 generated; @@ -500,7 +500,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_INT32: { - gtl::FlatSet already_generated; + absl::flat_hash_set already_generated; std::uniform_int_distribution distribution(-(1 << 20), 1 << 20); test::FillFn(&tensor, [&](int i) -> int32 { int32 generated; @@ -513,7 +513,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_INT64: { - gtl::FlatSet already_generated; + absl::flat_hash_set already_generated; std::uniform_int_distribution distribution(-(1LL << 40), 1LL << 40); test::FillFn(&tensor, [&](int i) -> int64 { @@ -527,7 +527,7 @@ Tensor OpTest::RandomTensor(DataType dtype, 
bool needs_unique_values, break; } case DT_BOOL: { - gtl::FlatSet already_generated; + absl::flat_hash_set already_generated; std::bernoulli_distribution distribution; test::FillFn(&tensor, [&](int i) -> bool { bool generated; @@ -1820,7 +1820,7 @@ TEST_F(OpTest, Diag) { do { dims = RandomDims(1); size = TensorShape(dims).num_elements(); - } while (size * size < tf_xla_max_tensor_size); + } while (size * size > tf_xla_max_tensor_size); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Diag").RandomInput(type, dims).Attr("T", type)); }); diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index dbf4beb693ec1766e6b7b5daaed4be4e1d874fba..3e499c2fb176a6d63fe3590e18a4a90e461e096a 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -48,13 +48,32 @@ class XlaSortOpTest(xla_test.XLATestCase): self.assertAllClose(v, result, rtol=1e-3) def testSort(self): - supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): x = np.arange(101, dtype=dtype) np.random.shuffle(x) self._assertOpOutputMatchesExpected( xla.sort, [x], expected=[np.arange(101, dtype=dtype)]) + def testKeyValueSort(self): + supported_key_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + supported_value_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32, + dtypes.int64.as_numpy_dtype, dtypes.uint64.as_numpy_dtype]) + for key_type in supported_key_types.intersection(self.numeric_types): + for value_type in supported_value_types.intersection(self.numeric_types): + x = np.arange(101, dtype=key_type) + np.random.shuffle(x) + y = (-x).astype(value_type) + self._assertOpOutputMatchesExpected( + xla.key_value_sort, [x, y], + expected=[ + np.arange(101, dtype=key_type), + -np.arange(101, 
dtype=value_type) + ]) + def testTopK(self): supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index f3861043b27ebb131554ff49af8c986229fc15ac..b7747414ead7599d885b319b758976328aaf788b 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -23,9 +23,9 @@ import math import numpy as np from tensorflow.compiler.tests import xla_test -from tensorflow.contrib import stateless from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import stateless_random_ops as stateless from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import test @@ -91,7 +91,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) - x = stateless.stateless_random_uniform( + x = stateless.stateless_random_normal( shape=[10000], seed=seed_t, dtype=dtype) y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) self.assertTrue(np.all(np.isfinite(y))) diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 78244d0b366d9128a4c59f786e4c5ac12e743b75..46ca371c8abf1cb4710717a183ee12820c4c4ca0 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -920,6 +920,34 @@ class TensorArrayTest(xla_test.XLATestCase): def testTensorArrayEvalEmptyWithDefault(self): self._testTensorArrayEvalEmptyWithDefault() + def _testTensorArrayScatterRead(self, tf_dtype): + with self.cached_session() as session, self.test_scope(): + convert = _make_converter(tf_dtype) + + ta = tensor_array_ops.TensorArray( + 
dtype=tf_dtype, + tensor_array_name="foo", + size=10) + + indices = constant_op.constant([1, 8]) + value = constant_op.constant(convert([[1.0, -1.0], [10.0, -10.0]])) + id0 = array_ops.placeholder(dtypes.int32) + id1 = array_ops.placeholder(dtypes.int32) + + w = ta.scatter(indices, value) + r0 = w.read(id0) + r1 = w.read(id1) + + # Test aggregation of read + read_vals = session.run([r0, r1], feed_dict={id0: 1, id1: 8}) + self.assertAllEqual(convert([1.0, -1.0]), read_vals[0]) + self.assertAllEqual(convert([10.0, -10.0]), read_vals[1]) + + def testTensorArrayScatterRead(self): + for dtype in self.numeric_tf_types: + self._testTensorArrayScatterRead(dtype) + self._testTensorArrayScatterRead(dtypes.bool) + def testTensorArrayScatterReadAndGradients(self): with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( @@ -929,15 +957,18 @@ class TensorArrayTest(xla_test.XLATestCase): indices = constant_op.constant([1, 8]) value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) + id0 = array_ops.placeholder(dtypes.int32) + id1 = array_ops.placeholder(dtypes.int32) w = ta.scatter(indices, value) - r0 = w.read(1) - r1 = w.read(8) + r0 = w.read(id0) + r1 = w.read(id1) # Test combined gradients + aggregation of read(0). grad = gradients_impl.gradients( ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]]) - read_vals, grad_vals = session.run([[r0, r1], grad]) + read_vals, grad_vals = session.run([[r0, r1], grad], + feed_dict={id0: 1, id1: 8}) self.assertEqual(len(read_vals), 2) self.assertEqual(len(grad_vals), 1) diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5c079d595c440cac644f5461154509abe7b1d1ed --- /dev/null +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -0,0 +1,96 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ops which manipulate lists of tensors via bridge.""" + +# pylint: disable=g-bad-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +def scalar_shape(): + return ops.convert_to_tensor([], dtype=dtypes.int32) + + +class ListOpsTest(xla_test.XLATestCase): + + def testElementShape(self): + with self.cached_session() as sess, self.test_scope(): + dim = array_ops.placeholder(dtypes.int32) + l = list_ops.tensor_list_reserve( + element_shape=(dim, 15), num_elements=20, + element_dtype=dtypes.float32) + e32 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32) + e64 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int64) + self.assertAllEqual(sess.run(e32, {dim: 10}), (10, 15)) + self.assertAllEqual(sess.run(e64, {dim: 7}), (7, 15)) + + def testPushPop(self): + with self.cached_session() as sess, self.test_scope(): + num = array_ops.placeholder(dtypes.int32) + 
l = list_ops.tensor_list_reserve( + element_shape=(7, 15), num_elements=num, element_dtype=dtypes.float32) + l = list_ops.tensor_list_push_back( + l, constant_op.constant(1.0, shape=(7, 15))) + l = list_ops.tensor_list_push_back( + l, constant_op.constant(2.0, shape=(7, 15))) + l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(sess.run(e2, {num: 10}), 2.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e1, {num: 10}), 1.0 * np.ones((7, 15))) + + def testPushPopSeparateLists(self): + with self.cached_session() as sess, self.test_scope(): + num = array_ops.placeholder(dtypes.int32) + l = list_ops.tensor_list_reserve( + element_shape=scalar_shape(), + num_elements=num, + element_dtype=dtypes.float32) + l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0)) + l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0)) + l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0)) + _, e11 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + l2, e21 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32) + l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32) + l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32) + l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32) + result = sess.run([e11, [e21, e22], [e31, e32]], {num: 20}) + self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]]) + + def testEmptyTensorList(self): + dim = 7 + with self.cached_session() as sess, self.test_scope(): + p = array_ops.placeholder(dtypes.int32) + l = list_ops.empty_tensor_list( + element_shape=(p, 15), element_dtype=dtypes.float32) + l = list_ops.tensor_list_push_back( + l, constant_op.constant(1.0, shape=(dim, 15))) + _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Use 
TensorListReserve instead"): + self.assertEqual(sess.run(e, {p: dim}), 1.0 * np.ones((dim, 15))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/test_utils.py b/tensorflow/compiler/tests/test_utils.py index 6abde18ea91f16d153a154b94effab037a911c6c..0e77dbf1a79d3dbacb77bab8b8e3df9bcc6287e1 100644 --- a/tensorflow/compiler/tests/test_utils.py +++ b/tensorflow/compiler/tests/test_utils.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin def ConvertBetweenDataFormats(x, data_format_src, data_format_dst): @@ -61,3 +62,14 @@ def PermuteDimsBetweenDataFormats(dims, data_format_src, data_format_dst): dim_map = {d: i for i, d in enumerate(data_format_src)} permuted_dims = [dims[dim_map[d]] for d in data_format_dst] return permuted_dims + + +_JIT_WARMUP_ITERATIONS = 10 + + +def RunWithWarmup(sess, op_to_run, feed_dict, options=None, run_metadata=None): + """Runs a graph a few times to ensure that its clusters are compiled.""" + for _ in xrange(0, _JIT_WARMUP_ITERATIONS): + sess.run(op_to_run, feed_dict, options=options) + return sess.run( + op_to_run, feed_dict, options=options, run_metadata=run_metadata) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index ba1e3b2b4fdbb73e98105ace6571783ef780adf5..f0e7791e9811533502fae0d4dea5a2e1ca2cf33c 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -308,6 +308,7 @@ tf_cc_test( "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:math_ops_op_lib", "//tensorflow/core:protos_all_cc", @@ -635,6 +636,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ops", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", 
], ) @@ -649,6 +651,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) @@ -659,5 +662,6 @@ cc_library( hdrs = ["side_effect_util.h"], deps = [ "//tensorflow/core:core_cpu", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index ea8d1b3d14939d4f4fba598318200f71c2eb0270..adcdb6c8f762cb7ea68485167bf7fc8ccb343a51 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -30,14 +30,15 @@ cc_library( tf_gen_op_wrapper_cc( name = "xla_jit_op_gen", - out_ops_file = "ops/xla_jit_op", + include_internal_ops = 1, + out_ops_file = "ops/xla_jit_ops", deps = ["//tensorflow/compiler/jit/ops:xla_ops"], ) cc_library( name = "xla_jit_ops", - srcs = ["ops/xla_jit_op.cc"], - hdrs = ["ops/xla_jit_op.h"], + srcs = ["ops/xla_jit_ops.cc"], + hdrs = ["ops/xla_jit_ops.h"], deps = [ "//tensorflow/cc:const_op", "//tensorflow/cc:ops", diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index db256e577a1f3dd38e04d102f60182023b9d43b2..46649b8cc43016d4a62f49e20256c77ca8accc79 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -695,6 +695,12 @@ Status Conditional::BuildIfNode(Graph* graph, VLOG(3) << "Build output type: " << DataTypeVectorString(out_type); builder.Attr("Tcond", DT_BOOL); + string outside_compilation; + if (GetNodeAttr(predicate_.node->def(), kXlaOutsideCompilationAttrName, + &outside_compilation) + .ok()) { + builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation); + } builder.Device(predicate_.node->assigned_device_name()); // Conditional should be the first input ... 
builder.Input(NodeDefBuilder::NodeOut(predicate_.node->name(), diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 36c6f5d316428d0d7cf7daa993ca14b52bfe0c05..f818d80022da0bad851c896f2714c15b20b22195 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -79,7 +79,10 @@ Status FunctionalizeControlFlowForFunction( const string& func_name, const string& new_func_name, const protobuf::Map& attrs, FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, - std::map* canonicalized_name_to_new_name) { + std::map>* canonicalized_name_to_new_name, + bool* modified) { + *modified = false; + // Convert the function to Graph. FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); @@ -91,44 +94,20 @@ Status FunctionalizeControlFlowForFunction( } }); const FunctionBody* body = flr->GetFunctionBody(handle); + Graph* g = body->graph; - // Call graph optimizer. The most important optimization we need is constant - // folding, which will replace ops like Shape/BroadcastGradientArgs with - // constant shape input. Without this optimization, those ops might become - // dynamic input for then/else body function and XLA will complain that input - // is not compile time constant. We enable function inlining as well, because - // otherwise we won't be able to infer shape for any node depending on - // function call nodes. - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_before_opt_", func_name), - *body->graph, fld); - } - // Optimizer accepts std::unique_ptr* as input and might change - // underlying pointer, thus we create a new Graph and copy from body->graph. 
- std::unique_ptr optimized_graph(new Graph(fld)); - CopyGraph(*body->graph, optimized_graph.get()); - OptimizerOptions opts; - opts.set_opt_level(OptimizerOptions::L0); - opts.set_do_function_inlining(true); - opts.set_do_constant_folding(true); - GraphOptimizer optimizer(opts); - auto cf_consider_fn = [](const Node* n) { - // Skip SymbolicGradient op when doing constant folding. - // Enabling SymbolicGradient op in constant folding requires - // flr->device() to be non-null, and here we have not constructed - // proper Device object yet (it will be constructed in XlaCompiler). - return n->type_string() != FunctionLibraryDefinition::kGradientOp; - }; - optimizer.Optimize(flr, flr->env(), - /*device=*/nullptr, &optimized_graph, - /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr, - cf_consider_fn); - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_opt_", func_name), - *optimized_graph, fld); + // Check if the graph has Switch or Merge node. + bool has_switch_or_merge = false; + for (Node* n : body->graph->nodes()) { + if (n->type_string() == "Switch" || n->type_string() == "Merge") { + has_switch_or_merge = true; + break; + } } + // We cannot return here directly if the graph has no Switch/Merge. + // It might contain function call nodes, or If/While nodes with Switch/Merge + // in function body. We still need to rewrite those functions and modify + // corresponding nodes. // If any node has associated functions, functionalize them first. // Gather nodes with associated functions first, because rewriting those nodes @@ -136,8 +115,8 @@ Status FunctionalizeControlFlowForFunction( // it. 
std::vector>> nodes_to_associated_functions; - for (auto* n : optimized_graph->nodes()) { - auto associated_functions = GetAssociatedFunctions(*n, flr); + for (auto* n : g->nodes()) { + auto associated_functions = GetAssociatedFunctions(*n, fld); if (!associated_functions.empty()) { nodes_to_associated_functions.push_back({n, associated_functions}); } @@ -151,10 +130,15 @@ Status FunctionalizeControlFlowForFunction( Canonicalize(name, AttrSlice(&associated_function.attrs())); auto iter = canonicalized_name_to_new_name->find(canonicalized_name); string new_name; + bool function_modified; if (iter != canonicalized_name_to_new_name->end()) { - // If we already functionalized this function, skip functionalization - // but still rewrite the node. - new_name = iter->second; + // If we already processed this function, check if it was rewritten. If + // the function was rewritten, the entry will be non-empty. Otherwise + // the entry will be empty. + function_modified = iter->second.has_value(); + if (function_modified) { + new_name = iter->second.value(); + } } else { if (associated_function.type() == AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) { @@ -166,42 +150,62 @@ Status FunctionalizeControlFlowForFunction( } TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( name, new_name, associated_function.attrs(), fld, flr, - canonicalized_name_to_new_name)); - (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + canonicalized_name_to_new_name, &function_modified)); + if (function_modified) { + // If the function was rewritten, add a non-empty entry. So later we + // know we have processed this function, and it was rewritten into + // another function. + (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + } else { + // If the function was not rewritten, add an empty entry. So later + // we know we have processed this function, and it does not need to be + // rewritten. 
+ (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt; + } + } + if (function_modified) { + *modified = true; + + // Notice that if "n" is a function call, RewriteAssociatedFunction() + // will delete it and create a new node instead, making "n" an invalid + // pointer. That's fine because in that case, associated_functions will + // only have one member and the loop will only run once. + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + g, n, fld, associated_function, new_name)); } - // Notice that if "n" is a function call, RewriteAssociatedFunction() will - // delete it and create a new node instead, making "n" an invalid pointer. - // That's fine because in that case, associated_functions will only have - // one member and the loop will only run once. - TF_RETURN_IF_ERROR(RewriteAssociatedFunction( - optimized_graph.get(), n, fld, associated_function, new_name)); } } - // Functionalize the function body. - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_before_fdef_", func_name), - *optimized_graph, fld); - } - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld)); - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_fdef_", func_name), - *optimized_graph, fld); + if (has_switch_or_merge) { + *modified = true; + + // Functionalize the function body. + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *g, fld); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld)); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, + fld); + } } - FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name, - &functionalized_fdef)); - // Add rewritten FunctionDef into library. 
- if (func_name == new_func_name) { - VLOG(2) << "Replacing function " << func_name; + if (*modified) { + // Add rewritten FunctionDef into library. + FunctionDef functionalized_fdef; TF_RETURN_IF_ERROR( - fld->ReplaceFunction(new_func_name, functionalized_fdef)); - } else { - VLOG(2) << "Adding function " << new_func_name; - TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + GraphToFunctionDef(*g, new_func_name, &functionalized_fdef)); + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + } } return ret_status; @@ -222,12 +226,16 @@ Status FunctionalizeControlFlowPass::Run( pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); // Find XLA compile ops and its corresponding FunctionDef. + // TPUCompile op is not in the map because graph rewriting might happen + // multiple times, and we want to avoid functionalize it again. static std::map* kNodeTypeToFunctionAttrMapping = new std::map{ - {"TPUCompile", "function"}, + // TPUReplicate ops are generated by EncapsulateTPUComputationsPass. + {"TPUReplicate", "computation"}, + // XlaLaunch ops are generated by EncapsulateXlaComputationsPass. {"XlaLaunch", "function"}, }; - std::map canonicalized_name_to_new_name; + std::map> canonicalized_name_to_new_name; for (Node* n : graph->nodes()) { auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); if (it == kNodeTypeToFunctionAttrMapping->end()) { @@ -242,12 +250,15 @@ Status FunctionalizeControlFlowPass::Run( << ". 
Corresponding function: " << func.name(); string new_func_name = options.flib_def->UniqueFunctionName( absl::StrCat(func.name(), "_f15n_")); + bool modified; TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name)); - n->ClearAttr(func_attr); - func.set_name(new_func_name); - n->AddAttr(func_attr, func); + &canonicalized_name_to_new_name, &modified)); + if (modified) { + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); + } } } diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 7c3ad448ef546dd1ab2640a57d7d1d73ca3768ad..d87436a7b4ac37c74d0f0df921779c8716290013 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -523,6 +523,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, builder.Attr("T", arg_types); builder.Attr("cond", cond_name); builder.Attr("body", body_name); + string outside_compilation; + if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName, + &outside_compilation) + .ok()) { + builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation); + } std::vector inputs; for (int i = 0; i < frame->args.size(); ++i) { const Arg& arg = frame->args[i]; diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 3e823254d3d52e88552712b4f53fa4449586cd20..9ee4178f5c213e919255bb33e9b15800a77256e6 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -40,6 +40,7 @@ tf_kernel_library( "dynamic_stitch_op.cc", "elu_op.cc", "extract_image_patches_op.cc", + "fake_param_op.cc", "fake_quantize_ops.cc", "fft_ops.cc", "fill_op.cc", @@ -62,6 +63,7 @@ tf_kernel_library( "one_hot_op.cc", "pack_op.cc", "pad_op.cc", + "permute_op.cc", "pooling_ops.cc", "qr_op.cc", 
"quantize_and_dequantize_op.cc", @@ -94,6 +96,7 @@ tf_kernel_library( "stateless_random_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", + "tensor_list_ops.cc", "tile_ops.cc", "topk_op.cc", "training_ops.cc", @@ -119,6 +122,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", + "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/lib:cholesky", "//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:random", @@ -157,6 +161,7 @@ tf_kernel_library( "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:conv_ops", "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:list_kernels", "//tensorflow/core/kernels:no_op", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/kernels:pooling_ops", diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index a988d3c33ed808b022f67882c8ae5100b7e7a305..47e517a6576d3a848bc41ceb703df2bd778c4a35 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -64,7 +64,7 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); // } static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto y_equals_0 = xla::Eq(y, zero); auto zeros = xla::ZerosLike(x); @@ -84,7 +84,7 @@ XLA_MAKE_BINARY(DivNoNan, // } static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); if 
(DataTypeIsUnsigned(dtype)) { return xla::Div(x, y); } @@ -105,7 +105,7 @@ XLA_MAKE_BINARY(FloorDiv, static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto is_zero = xla::Eq(x, zero); return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y))); @@ -114,7 +114,7 @@ XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper)); static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto is_zero = xla::Eq(x, zero); return xla::Select(is_zero, zero, xla::Div(x, y)); @@ -126,7 +126,7 @@ XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper)); // return (x < T(0)) == (y < T(0)) ? 
trunc_mod : std::fmod(trunc_mod + y, y); static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero)); auto trunc_mod = xla::Rem(x, y); diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc index 696c1c39befd5aa2972afb6cfa64905b57a5ab72..9bb11fb67e3e4ddc48d68631c60f96c60b921094 100644 --- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -13,16 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "absl/algorithm/container.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/bcast.h" namespace tensorflow { namespace { @@ -37,59 +32,9 @@ class BroadcastToOp : public XlaOpKernel { TensorShape output_shape; OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); - OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(), - errors::InvalidArgument( - "Input rank (", input_shape.dims(), - ") must be less than or equal to the output rank (", - output_shape.dims(), ")")); - - auto input_dims = input_shape.dim_sizes(); - auto 
output_dims = output_shape.dim_sizes(); - - // Broadcasting is done right-to-left on right-aligned dimensions; reverse - // the two vectors so elements to be broadcast are aligned. - absl::c_reverse(input_dims); - absl::c_reverse(output_dims); - - std::vector broadcast_dims; - std::vector broadcast_shape; - for (int i = 0; i < output_shape.dims(); ++i) { - if (i < input_shape.dims()) { - OP_REQUIRES( - context, - (output_dims[i] == 0 && input_dims[i] == 0) || - (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0), - errors::InvalidArgument("invalid shape to broadcast from ", - input_shape.DebugString(), " to ", - output_shape.DebugString())); - - broadcast_dims.push_back(broadcast_shape.size()); - if (output_dims[i] == input_dims[i]) { - broadcast_shape.push_back(output_dims[i]); - } else if (output_dims[i] != input_dims[i]) { - // Add dimensions [I, O/I], which we will later flatten to just - // [O]. We must do this in two phases since XLA broadcasting does not - // support tiling. 
- broadcast_shape.push_back(input_dims[i]); - broadcast_shape.push_back(output_dims[i] / input_dims[i]); - } - } else { - broadcast_shape.push_back(output_dims[i]); - } - } - absl::c_reverse(broadcast_dims); - int broadcast_shape_size = broadcast_shape.size(); - for (int64& broadcast_dim : broadcast_dims) { - broadcast_dim = broadcast_shape_size - broadcast_dim - 1; - } - absl::c_reverse(broadcast_shape); - xla::XlaOp output = xla::Reshape( - xla::BroadcastInDim(context->Input(0), - xla::ShapeUtil::MakeShape( - context->input_xla_type(0), broadcast_shape), - broadcast_dims), - output_shape.dim_sizes()); - context->SetOutput(0, output); + auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes()); + OP_REQUIRES_OK(context, output.status()); + context->SetOutput(0, output.ValueOrDie()); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index da8cf3fc6fa694f592280f8c249d317827d9cd09..2628ef8e2454976aeff3859fa5dc1d8e106f32e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { @@ -76,6 +77,17 @@ class ConstOp : public XlaOpKernel { return; } break; + case DT_COMPLEX64: + if (proto_.scomplex_val_size() == 2) { + ctx->SetOutput( + 0, + xla::Broadcast(xla::ConstantR0( + b, xla::complex64(proto_.scomplex_val(0), + proto_.scomplex_val(1))), + shape.dim_sizes())); + return; + } + break; case DT_INT32: if (proto_.int_val_size() == 1) { ctx->SetOutput( diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index ef1015552d181a183d412f9c269dd5ec608b388f..234f7b4a019c9aac4bac4f906ddbae166ecd9a80 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -39,7 +40,8 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { // compute valid broadcast shapes, but rely below on XLA to // automatically perform the broadcast assuming its valid shapes are // a superset of TensorFlow's valid shapes. - BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape)); + BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape), + /*fewer_dims_optimization=*/false); if (!bcast.IsValid()) { ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ", lhs_shape.DebugString(), " vs. 
", @@ -86,51 +88,18 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { } /* static */ std::pair XlaBinaryOp::Broadcast( - xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, - const BCast& broadcast_helper) { - // Manually construct the broadcasting since MapN does not do - // automatic broadcasting. The bcast helper ensures that - // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and - // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have - // the same shape, so can be operated on by MapN. - - // First reshape the inputs, which should be a metadata-only - // operation since we are flattening the dimensions in order. - auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape()); - auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape()); - - // Next broadcast the necessary input dimensions. We rely on the - // XLA optimizer to be smart about the fact that we are asking - // it to broadcast size 1 on some of these dimensions, to avoid - // adding complexity to this code. - auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast()); - int lhs_size = broadcast_helper.x_bcast().size(); - auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast()); - int rhs_size = broadcast_helper.y_bcast().size(); - - // Now reshape them to the correct output shape. After the - // broadcast each side is twice as wide as it should be, since the - // broadcast dimensions were prepended to the shape. Reshape - // flattening each original dimension with the prepended broadcast - // dimension. E.g. if we started out with lhs_shaped with shape - // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have - // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21]. 
- std::vector lhs_reorder; - for (int i = 0; i < lhs_size; ++i) { - lhs_reorder.push_back(i); - lhs_reorder.push_back(i + lhs_size); + xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper) { + auto lhs_output = BroadcastTo(lhs, broadcast_helper.output_shape()); + if (!lhs_output.ok()) { + xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status()); + return {error, error}; } - auto lhs_output = - xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape()); - std::vector rhs_reorder; - for (int i = 0; i < rhs_size; ++i) { - rhs_reorder.push_back(i); - rhs_reorder.push_back(i + rhs_size); + auto rhs_output = BroadcastTo(rhs, broadcast_helper.output_shape()); + if (!rhs_output.ok()) { + xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status()); + return {error, error}; } - auto rhs_output = - xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape()); - - return {lhs_output, rhs_output}; + return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()}; } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index 6653944a911588b7bc88d67b8cdd2c17850530f0..516ead4bfe89b4ddeee11dcc6410a838d04f28a9 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -67,8 +67,7 @@ class XlaBinaryOp : public XlaOpKernel { // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same // shape. 
static std::pair Broadcast( - xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, - const BCast& broadcast_helper); + xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper); }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc b/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ec3463bd58f55c1fc6a8f7c074c8e487d266d7b6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { + +// This OpKernel implements the FakeParam Op for XLA JIT devices. Create zeros +// with the appropriate shape for FakeParam op. 
+class XlaFakeParamOp : public XlaOpKernel { + public: + explicit XlaFakeParamOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + DataType dtype; + TensorShape tensor_shape; + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tensor_shape)); + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, tensor_shape, &shape_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + ctx->SetOutput(0, xla::Zeros(b, shape_)); + } + + private: + xla::Shape shape_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaFakeParamOp); +}; + +REGISTER_XLA_OP(Name("FakeParam"), XlaFakeParamOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 921b4340c0ac674a5ad7d17aaf54f1cf36975151..6713d6bc921b24b25baddfb3fd7296fffcc3d6ea 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -316,6 +318,70 @@ class AdjustHueOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); +struct WhileCondFn { + const int64 num_boxes; + const int64 output_size; + + explicit WhileCondFn(int64 num_boxes, int64 output_size) + : num_boxes(num_boxes), output_size(output_size) {} + + xla::StatusOr operator()(absl::Span values, + xla::XlaBuilder* cond_builder) const { + xla::XlaOp row_idx = values[0]; + xla::XlaOp row_in_bounds = + xla::Lt(row_idx, xla::ConstantR0(cond_builder, num_boxes)); + xla::XlaOp num_outputs_so_far = values[1]; + xla::XlaOp results_not_full = xla::Lt( + num_outputs_so_far, xla::ConstantR0(cond_builder, output_size)); + return xla::And(row_in_bounds, results_not_full); + } +}; + +// Process the boxes one-by-one using the iou matrix mask. +// This implementation uses a correct, but greedy, sequential algorithm +// to ensure that suppressed boxes cannot themselves suppress other +// boxes. +struct SuppressBodyFn { + const int64 num_boxes; + + explicit SuppressBodyFn(int64 num_boxes) : num_boxes(num_boxes) {} + + xla::StatusOr> operator()( + absl::Span values, xla::XlaBuilder* builder) const { + auto row_idx = values[0]; + auto num_outputs_so_far = values[1]; + auto iou_mask = values[2]; + auto included_iou = values[3]; + auto zero_r1 = xla::ConstantR1(builder, {0}); + // Determine if current elem is active using a slice. 
+ auto row_idx_r1 = xla::Reshape(row_idx, {1}); + auto active_elem = xla::DynamicSlice(included_iou, row_idx_r1, {1}); + active_elem = xla::Reshape(active_elem, {}); + // Increment output count iff current elem is not suppressed. + num_outputs_so_far = xla::Select( + active_elem, num_outputs_so_far + xla::ConstantR0(builder, 1), + num_outputs_so_far); + // Slice out the row_idx. + auto starts = xla::ConcatInDim(builder, {row_idx_r1, zero_r1}, 0); + auto row_iou = xla::DynamicSlice(iou_mask, starts, {1, num_boxes}); + // Remove the diagonal from consideration. An elem cannot suppress + // itself. + auto update_starts = xla::ConcatInDim(builder, {zero_r1, row_idx_r1}, 0); + row_iou = xla::DynamicUpdateSlice( + row_iou, xla::ConstantR2FromArray2D(builder, {{false}}), + update_starts); + // Create a suppression by inverting polarity. + row_iou = xla::Reshape(row_iou, {num_boxes}); + auto supp_mask = xla::Not(row_iou); + // Update mask iff current elem is not suppressed. + included_iou = xla::Select(xla::Broadcast(active_elem, {num_boxes}), + xla::And(included_iou, supp_mask), included_iou); + row_idx = row_idx + xla::ConstantR0(builder, 1); + return std::vector{row_idx, num_outputs_so_far, iou_mask, + included_iou}; + } +}; + class NonMaxSuppressionOp : public XlaOpKernel { public: explicit NonMaxSuppressionOp(OpKernelConstruction* context) @@ -326,14 +392,12 @@ class NonMaxSuppressionOp : public XlaOpKernel { void Compile(XlaOpKernelContext* context) override { // TODO(b/111646731): Improve scalability of this op, using blocking. 
- int num_boxes_dim = 0; - int coords_dim = 1; const TensorShape& boxes_shape = context->InputShape("boxes"); OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape), errors::InvalidArgument("boxes must be 2-D, currently: ", boxes_shape.DebugString())); - const int64 num_boxes = boxes_shape.dim_size(num_boxes_dim); - OP_REQUIRES(context, boxes_shape.dim_size(coords_dim) == 4, + const int64 num_boxes = boxes_shape.dim_size(0); + OP_REQUIRES(context, boxes_shape.dim_size(1) == 4, errors::InvalidArgument("boxes must have 4 columns", boxes_shape.DebugString())); const TensorShape& scores_shape = context->InputShape("scores"); @@ -347,9 +411,13 @@ class NonMaxSuppressionOp : public XlaOpKernel { OP_REQUIRES(context, pad_to_max_output_size_, errors::InvalidArgument( "XLA compilation requires pad_to_max_output_size == True")); + OP_REQUIRES(context, num_boxes <= kint32max, + errors::InvalidArgument("XLA compilation requires number of " + "boxes to be <= kint32max, got ", + num_boxes)); - xla::XlaOp boxes = context->Input("boxes"); - xla::XlaOp scores = context->Input("scores"); + const xla::XlaOp boxes_input = context->Input("boxes"); + const xla::XlaOp scores_input = context->Input("scores"); int64 output_size; OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size)); OP_REQUIRES( @@ -358,90 +426,113 @@ class NonMaxSuppressionOp : public XlaOpKernel { OP_REQUIRES(context, output_size <= kint32max, errors::InvalidArgument("Need output_size <= kint32Max, got ", output_size)); - xla::XlaOp score_thresh = context->Input("score_threshold"); - xla::XlaOp iou_thresh = context->Input("iou_threshold"); - + const xla::XlaOp score_thresh = context->Input("score_threshold"); + const xla::XlaOp iou_thresh = context->Input("iou_threshold"); xla::XlaBuilder* const builder = context->builder(); // Choose a more convenient layout. 
- xla::XlaOp boxes_t = xla::Transpose(boxes, {1, 0}); - coords_dim = 0; - num_boxes_dim = 1; - - // Shapes are henceforth [1, num_boxes]. - xla::XlaOp coord_y0 = xla::SliceInDim(boxes_t, - /*start_index=*/0, - /*limit_index=*/1, - /*stride=*/1, - /*dimno=*/coords_dim); - xla::XlaOp coord_x0 = xla::SliceInDim(boxes_t, - /*start_index=*/1, - /*limit_index=*/2, - /*stride=*/1, - /*dimno=*/coords_dim); - xla::XlaOp coord_y1 = xla::SliceInDim(boxes_t, - /*start_index=*/2, - /*limit_index=*/3, - /*stride=*/1, - /*dimno=*/coords_dim); - xla::XlaOp coord_x1 = xla::SliceInDim(boxes_t, - /*start_index=*/3, - /*limit_index=*/4, - /*stride=*/1, - /*dimno=*/coords_dim); - xla::XlaOp y1 = - xla::Select(xla::Le(coord_y0, coord_y1), coord_y0, coord_y1); - xla::XlaOp y2 = - xla::Select(xla::Le(coord_y0, coord_y1), coord_y1, coord_y0); - xla::XlaOp x1 = - xla::Select(xla::Le(coord_x0, coord_x1), coord_x0, coord_x1); - xla::XlaOp x2 = - xla::Select(xla::Le(coord_x0, coord_x1), coord_x1, coord_x0); + const xla::XlaOp boxes = xla::Transpose(boxes_input, {1, 0}); + const xla::XlaOp boxes_sorted = xla::GetTupleElement( + xla::Sort(/*keys=*/-xla::Broadcast(scores_input, {4}), + /*values=*/{boxes}, + /*dimension=*/1), + 1); + // Track the mapping of indices into sorted domain. + const xla::XlaOp iota_indices = xla::Iota(builder, xla::S32, num_boxes); + const xla::XlaOp indices_sort = xla::Sort(-scores_input, {iota_indices}); + const xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1); + const xla::XlaOp scores = xla::Neg(xla::GetTupleElement(indices_sort, 0)); + + // Shapes are henceforth [1, num_boxes]. 'c_y0' denotes 'coordinate' y0. 
+ const xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted, + /*start_index=*/0, + /*limit_index=*/1, + /*stride=*/1, + /*dimno=*/0), + {num_boxes}); + const xla::XlaOp c_x0 = xla::Reshape(xla::SliceInDim(boxes_sorted, + /*start_index=*/1, + /*limit_index=*/2, + /*stride=*/1, + /*dimno=*/0), + {num_boxes}); + const xla::XlaOp c_y1 = xla::Reshape(xla::SliceInDim(boxes_sorted, + /*start_index=*/2, + /*limit_index=*/3, + /*stride=*/1, + /*dimno=*/0), + {num_boxes}); + const xla::XlaOp c_x1 = xla::Reshape(xla::SliceInDim(boxes_sorted, + /*start_index=*/3, + /*limit_index=*/4, + /*stride=*/1, + /*dimno=*/0), + {num_boxes}); + + xla::XlaOp y1 = xla::Select(xla::Le(c_y0, c_y1), c_y0, c_y1); + xla::XlaOp y2 = xla::Select(xla::Le(c_y0, c_y1), c_y1, c_y0); + xla::XlaOp x1 = xla::Select(xla::Le(c_x0, c_x1), c_x0, c_x1); + xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0); xla::XlaOp area = (y2 - y1) * (x2 - x1); - // Transpose the 1xN tensors, instead of the NxN tensors. - xla::XlaOp y1_t = xla::Transpose(y1, {1, 0}); - xla::XlaOp y2_t = xla::Transpose(y2, {1, 0}); - xla::XlaOp x1_t = xla::Transpose(x1, {1, 0}); - xla::XlaOp x2_t = xla::Transpose(x2, {1, 0}); - xla::XlaOp area_t = xla::Transpose(area, {1, 0}); + // Shapes are henceforth [1, num_boxes]. + y1 = xla::Broadcast(y1, {1}); + y2 = xla::Broadcast(y2, {1}); + x1 = xla::Broadcast(x1, {1}); + x2 = xla::Broadcast(x2, {1}); + area = xla::Broadcast(area, {1}); // Shapes are henceforth [num_boxes, num_boxes]. 
- xla::XlaOp i_xmin = xla::Max(x1, x1_t); - xla::XlaOp i_ymin = xla::Max(y1, y1_t); - xla::XlaOp i_xmax = xla::Min(x2, x2_t); - xla::XlaOp i_ymax = xla::Min(y2, y2_t); + xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0})); + xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0})); + xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0})); + xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0})); auto square_zero = xla::ZerosLike(i_xmin); xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) * xla::Max(i_ymax - i_ymin, square_zero); - xla::XlaOp u_area = area + area_t - i_area; + xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area; xla::XlaOp iou = i_area / u_area; xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero); - xla::XlaOp scores_2d = xla::Reshape(scores, {num_boxes, 1}); - xla::XlaOp score_cmp_mask = - xla::Gt(scores_2d, xla::Transpose(scores_2d, {1, 0})); - xla::XlaOp suppress = xla::And(iou_thresh_mask, score_cmp_mask); - - // Shapes are [num_boxes] after the reduce. 
- xla::XlaOp included_iou = xla::Not(xla::Reduce( - suppress, - /*init_value=*/xla::ConstantR0(builder, false), - /*computation=*/CreateScalarOrComputation(xla::PRED, builder), - /*dimensions_to_reduce=*/{0})); + xla::XlaOp included_iou = + xla::Broadcast(xla::ConstantR0(builder, true), {num_boxes}); + + std::vector init_values; + init_values.reserve(4); + init_values.push_back(xla::ConstantR0(builder, 0)); // col_idx + init_values.push_back(xla::ConstantR0(builder, 0)); // num_outputs + init_values.push_back(iou_thresh_mask); + init_values.push_back(included_iou); + + auto suppress_loop_result = + XlaWhileLoop(WhileCondFn(num_boxes, output_size), + SuppressBodyFn(num_boxes), init_values, "suppress_loop", + builder) + .ValueOrDie(); + xla::XlaOp included_score = xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes})); - xla::XlaOp included = xla::And(included_iou, included_score); + xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]); + + // Only consider boxes over which we have iterated. This allows for accurate + // counting. DynamicSlice would require knowledge of the size of the output. + auto valid_elem = xla::Lt( + iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes})); + included = xla::And(included, valid_elem); + xla::XlaOp neg_inf = xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes}); xla::XlaOp scores_included = xla::Select(included, scores, neg_inf); - + xla::XlaOp output_tuple = TopK(scores_included, output_size); + xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1); + // Calculate num_valid. + // Note: num_valid cannot be taken from the loop outputs, because outputs + // can be suppressed by score threshold. xla::XlaOp ones_included = xla::Select( included, xla::Broadcast(xla::ConstantR0(builder, 1), {num_boxes}), xla::Broadcast(xla::ConstantR0(builder, 0), {num_boxes})); - // num_valid is scalar. Value should be bound by output_size. 
xla::XlaOp num_valid_total = xla::Reduce( ones_included, @@ -451,8 +542,17 @@ class NonMaxSuppressionOp : public XlaOpKernel { xla::XlaOp num_valid = xla::Min(num_valid_total, xla::ConstantR0(builder, output_size)); - xla::XlaOp output_tuple = TopK(scores_included, output_size); - xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1); + // Re-index into the original scores input tensor, using a Gather. + // Boxes were suppressed in the sorted domain. + xla::XlaOp selected_indices; + DataType gather_type = context->expected_output_dtype(0); + OP_REQUIRES_OK( + context, + XlaGather(indices_sorted, scores_shape, selected_indices_sorted, + TensorShape({output_size}), + /*axis=*/0, + /*indices_are_nd=*/false, + /*dtype=*/gather_type, DT_INT32, builder, &selected_indices)); context->SetOutput(0, selected_indices); context->SetOutput(1, num_valid); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 3d81ae9eb89a80e5b89b180ad77521c5ed15e79d..f210bfbd886e48b8d7972393ed1899491486646c 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -88,20 +88,30 @@ class ArgMaxCustomCallOp : public XlaOpKernel { xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0(dim))); } - xla::Shape xla_shape = - xla::ShapeUtil::MakeShape(xla::S64, output_shape.dim_sizes()); + // The argmax function expects row-major layout. 
+ xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( + xla::S64, output_shape.dim_sizes()); + std::vector arg_shapes; + for (const xla::XlaOp& arg : args) { + auto shape_status = b.GetShape(arg); + OP_REQUIRES_OK(ctx, shape_status.status()); + xla::Shape arg_shape = shape_status.ConsumeValueOrDie(); + *arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout( + xla::ShapeUtil::Rank(arg_shape)); + arg_shapes.push_back(std::move(arg_shape)); + } // Tell XLA to call the custom code, defined in // index_ops_kernel_argmax_float_1d.cc. xla::XlaOp output; switch (input_shape.dims()) { case 1: - output = - xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape); + output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args, + xla_shape, arg_shapes); break; case 2: - output = - xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape); + output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args, + xla_shape, arg_shapes); break; default: OP_REQUIRES(ctx, false, diff --git a/tensorflow/compiler/tf2xla/kernels/permute_op.cc b/tensorflow/compiler/tf2xla/kernels/permute_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..94b51e1a586c6cf623c181abf200b91851c7ba05 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/permute_op.cc @@ -0,0 +1,98 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace { + +class DataFormatVecPermuteOp : public XlaOpKernel { + public: + explicit DataFormatVecPermuteOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("src_format", &src_format_)); + OP_REQUIRES( + ctx, src_format_.size() == 4, + errors::InvalidArgument("Data format should have 4 characters")); + TensorFormat data_format; + OP_REQUIRES(ctx, FormatFromString(src_format_, &data_format), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dst_format", &dst_format_)); + OP_REQUIRES( + ctx, dst_format_.size() == 4, + errors::InvalidArgument("Data format should have 4 characters")); + OP_REQUIRES(ctx, FormatFromString(dst_format_, &data_format), + errors::InvalidArgument("Invalid data format")); + } + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + const TensorShape input_tensor_shape = ctx->InputShape(0); + int input_rank = input_tensor_shape.dims(); + OP_REQUIRES(ctx, input_rank == 1 || input_rank == 2, + errors::InvalidArgument( + "Input must be a vector or matrix, but got shape ", + input_tensor_shape.DebugString())); + OP_REQUIRES( + ctx, input_tensor_shape.dim_size(0) == 4, + errors::InvalidArgument( + "First dimension of input must be of size 4, but got shape ", + input_tensor_shape.DebugString())); + if (input_rank == 2) { + OP_REQUIRES( + ctx, input_tensor_shape.dim_size(1) == 2, + errors::InvalidArgument( + "Second dimension of 2D input must be of size 2, but got shape ", + 
input_tensor_shape.DebugString())); + } + std::vector dst_indices(4, 0); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + if (src_format_[i] == dst_format_[j]) { + dst_indices[i] = j; + break; + } + } + } + auto keys = xla::ConstantR1(builder, absl::Span(dst_indices)); + if (input_rank == 2) { + keys = xla::BroadcastInDim( + keys, xla::ShapeUtil::MakeShape(xla::S32, {4, 2}), {0}); + } + auto sorted = xla::Sort(keys, {ctx->Input(0)}, 0); + auto output = xla::GetTupleElement(sorted, 1); + ctx->SetOutput(0, output); + } + + private: + string src_format_; + string dst_format_; + + TF_DISALLOW_COPY_AND_ASSIGN(DataFormatVecPermuteOp); +}; + +REGISTER_XLA_OP( + Name("DataFormatVecPermute").TypeConstraint("T", {DT_INT32, DT_INT64}), + DataFormatVecPermuteOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index afd5986846705f66eb4c7ced9dbe2f4757f5af7f..7ef6fa305b7f5b5aae187808f856a9273f101e14 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -135,7 +135,7 @@ class RandomShuffleOp : public XlaOpKernel { xla::XlaOp curr = input; for (int i = 0; i < rounds; ++i) { xla::XlaOp keys = xla::RngUniform(zero, max_value, key_shape); - xla::XlaOp sorted = xla::Sort(keys, curr); + xla::XlaOp sorted = xla::Sort(keys, {curr}); curr = xla::GetTupleElement(sorted, 1); } diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index 8102faad28db71075fb8da269c55edbdb667193e..8eee5b12991fb377203d780cecd8916952bd699a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -40,10 +40,16 @@ class ReduceWindowOp : public XlaOpKernel { std::vector window_dimensions; std::vector window_strides; + std::vector base_dilations; + std::vector window_dilations; 
OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( "window_dimensions", &window_dimensions)); OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", &window_strides)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("base_dilations", + &base_dilations)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dilations", &window_dilations)); const int rank = input_shape.dims(); OP_REQUIRES(context, rank == window_dimensions.size(), @@ -56,6 +62,16 @@ class ReduceWindowOp : public XlaOpKernel { "The size of window_strides must be equal to the input " "rank (", window_strides.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == base_dilations.size(), + errors::InvalidArgument( + "The size of base_dilations must be equal to the input " + "rank (", + base_dilations.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_dilations.size(), + errors::InvalidArgument( + "The size of window_dilations must be equal to the input " + "rank (", + window_dilations.size(), " vs. ", rank, ")")); // Build the reducer function. 
XlaCompiler::Argument reducer_arg; @@ -102,7 +118,8 @@ class ReduceWindowOp : public XlaOpKernel { xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( context->Input(0), context->Input(1), *reducer.computation, - window_dimensions, window_strides, padding); + window_dimensions, window_strides, base_dilations, window_dilations, + padding); context->SetOutput(0, output); } @@ -115,6 +132,8 @@ class ReduceWindowOp : public XlaOpKernel { REGISTER_XLA_OP(Name("XlaReduceWindow") .CompileTimeConstInput("window_dimensions") .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("base_dilations") + .CompileTimeConstInput("window_dilations") .CompileTimeConstInput("padding"), ReduceWindowOp); diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index ab094d7dd1ce9856a3c2854fd2776827d6c4b76f..57afd608de820573821d605cadcc8779474b5fd6 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -104,7 +104,8 @@ class ScanOp : public XlaOpKernel { } auto output = xla::ReduceWindowWithGeneralPadding( XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, - *reducer, window_dims, window_strides, padding); + *reducer, window_dims, window_strides, + /*base_dilations=*/{}, /*window_dilations=*/{}, padding); output = XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0)); diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 25a5bcbe1dd27d741ce3b74125ba9ce425ee78f3..0c32b8def0f7b741c93e803f8359b6504087e257 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -18,7 +18,9 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -55,10 +57,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { // The type-specific part of the implementation of Range. template -Status CreateRangeTensor(const xla::LiteralSlice& start_literal, - const xla::LiteralSlice& limit_literal, - const xla::LiteralSlice& delta_literal, - Tensor* output) { +xla::StatusOr CreateRangeTensor( + const xla::LiteralSlice& start_literal, + const xla::LiteralSlice& limit_literal, + const xla::LiteralSlice& delta_literal, xla::XlaBuilder* builder) { T start = start_literal.Get({}); T limit = limit_literal.Get({}); T delta = delta_literal.Get({}); @@ -82,14 +84,10 @@ Status CreateRangeTensor(const xla::LiteralSlice& start_literal, ? 
((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta)) : std::ceil(std::abs((limit - start) / delta))); - *output = Tensor(DataTypeToEnum::v(), TensorShape({size})); - auto flat = output->flat(); - T val = start; - for (int64 i = 0; i < size; ++i) { - flat(i) = val; - val += delta; - } - return Status::OK(); + return xla::ConstantR0(builder, start) + + xla::ConstantR0(builder, delta) * + xla::Iota(builder, xla::primitive_util::NativeToPrimitiveType(), + size); } class RangeOp : public XlaOpKernel { @@ -115,27 +113,26 @@ class RangeOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta)); DataType type = input_type(0); - Tensor output; - Status status; + xla::StatusOr output; switch (type) { case DT_INT32: - status = CreateRangeTensor(start, limit, delta, &output); + output = CreateRangeTensor(start, limit, delta, ctx->builder()); break; case DT_INT64: - status = CreateRangeTensor(start, limit, delta, &output); + output = CreateRangeTensor(start, limit, delta, ctx->builder()); break; case DT_FLOAT: - status = CreateRangeTensor(start, limit, delta, &output); + output = CreateRangeTensor(start, limit, delta, ctx->builder()); break; case DT_DOUBLE: - status = CreateRangeTensor(start, limit, delta, &output); + output = CreateRangeTensor(start, limit, delta, ctx->builder()); break; default: - status = errors::InvalidArgument("Invalid type for Range ", + output = errors::InvalidArgument("Invalid type for Range ", DataTypeString(type)); } - OP_REQUIRES_OK(ctx, status); - ctx->SetConstantOutput(0, output); + OP_REQUIRES_OK(ctx, output.status()); + ctx->SetOutput(0, output.ValueOrDie()); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc index aaeeae01ccb303091a6d37d1aeb4b2a3377dc638..6cfdf4a5ae479e9851454df97160754f122bc6ff 100644 --- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -25,11 +25,26 @@ class XlaSortOp : public 
XlaOpKernel { explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* context) override { - context->SetOutput(0, xla::Sort(context->Input(0))); + context->SetOutput(0, xla::Sort(context->Input("input"))); } }; REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp); +class XlaKeyValueSortOp : public XlaOpKernel { + public: + explicit XlaKeyValueSortOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaOp result = + xla::Sort(context->Input("keys"), {context->Input("values")}); + context->SetOutput(0, xla::GetTupleElement(result, 0)); + context->SetOutput(1, xla::GetTupleElement(result, 1)); + } +}; + +REGISTER_XLA_OP(Name("XlaKeyValueSort"), XlaKeyValueSortOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 94108b764fd32fc77520f9a8ea16065c27e6accf..06a560d9471c352065ef7e9f6903ebdca542f5b1 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -123,9 +123,10 @@ Status GetTensorArrayShape(const XlaResource* resource, xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, absl::Span update_dims, - const xla::XlaOp& start_indices) { + const xla::XlaOp& start_indices, DataType dtype) { xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); - xla::XlaOp sum = xla::Add(current, update); + xla::XlaOp sum = + dtype == DT_BOOL ? 
xla::Or(current, update) : xla::Add(current, update); return xla::DynamicUpdateSlice(operand, sum, start_indices); } @@ -222,9 +223,16 @@ class TensorArrayWriteOp : public XlaOpKernel { slice_shape.InsertDim(0, 1LL); auto update = xla::Reshape(value, slice_shape.dim_sizes()); - xla::XlaOp written = - DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); - + xla::XlaOp written; + if (resource->tensor_array_multiple_writes_aggregate()) { + written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), + start_indices, dtype_); + } else { + // TODO(b/117569591): Ideally we would report an error in the case that we + // see multiple writes to the same offset. Unfortunately there is no way + // to report errors at the moment, so we silently overwrite. + written = xla::DynamicUpdateSlice(ta, update, start_indices); + } OP_REQUIRES_OK(ctx, resource->SetValue(written)); ctx->SetOutput(0, flow); } @@ -391,7 +399,11 @@ class TensorArrayScatterOp : public XlaOpKernel { } if (scatter_all_elements_in_order) { - ta = xla::Add(ta, value); + if (dtype_ == DT_BOOL) { + ta = xla::Or(ta, value); + } else { + ta = xla::Add(ta, value); + } } else { auto slice_dims = value_shape.dim_sizes(); slice_dims[0] = 1LL; @@ -414,7 +426,7 @@ class TensorArrayScatterOp : public XlaOpKernel { auto start_indices = xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); - ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices, dtype_); } } @@ -522,8 +534,13 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. 
", ta_shape.DebugString())); - OP_REQUIRES_OK(ctx, resource->SetValue(xla::Add( - ta, xla::Reshape(value, ta_shape.dim_sizes())))); + const xla::XlaOp reshape = xla::Reshape(value, ta_shape.dim_sizes()); + if (dtype_ == DT_BOOL) { + ta = xla::Or(ta, reshape); + } else { + ta = xla::Add(ta, reshape); + } + OP_REQUIRES_OK(ctx, resource->SetValue(ta)); ctx->SetOutput(0, flow); } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..74d4fcc425bdadb70a7bedf2487deaf6c4a4f7b9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -0,0 +1,226 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA TensorList operators. 
+ +#include +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op, + TensorShape* tensor_list_shape) { + auto shape_or_status = builder->GetShape(op); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + xla::Shape shape = shape_or_status.ValueOrDie(); + TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); + return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), + tensor_list_shape); +} + +class TensorListReserveOp : public XlaOpKernel { + public: + explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape element_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); + int64 num_elements; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); + + TensorShape tensor_shape; + tensor_shape.AddDim(num_elements); + tensor_shape.AppendShape(element_shape); + + xla::XlaBuilder* b = ctx->builder(); + ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), + 
tensor_shape.dim_sizes()), + xla::ConstantR0(b, 0)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp); +}; + +REGISTER_XLA_OP(Name("TensorListReserve") + .CompileTimeConstInput("element_shape") + .CompileTimeConstInput("num_elements"), + TensorListReserveOp); + +class EmptyTensorListOp : public XlaOpKernel { + public: + explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + ctx->CtxFailure( + errors::InvalidArgument("XLA compilation requires a fixed tensor list " + "size. Use TensorListReserve instead.")); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp); +}; + +REGISTER_XLA_OP(Name("EmptyTensorList"), EmptyTensorListOp); + +class TensorListElementShapeOp : public XlaOpKernel { + public: + explicit TensorListElementShapeOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + TensorShape shape; + OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape)); + shape.RemoveDim(0); + + switch (shape_type_) { + case DT_INT64: + ctx->SetOutput(0, xla::ConstantR1(b, shape.dim_sizes())); + break; + case DT_INT32: { + std::vector size; + for (int64 s : shape.dim_sizes()) { + size.push_back(s); + } + ctx->SetOutput(0, xla::ConstantR1(b, size)); + break; + } + default: + ctx->CtxFailure( + errors::InvalidArgument("Unsupported shape type requested")); + return; + } + } + + private: + DataType shape_type_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp); +}; + +REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp); + +class TensorListPushBackOp : public XlaOpKernel { + public: + explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void 
Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp list = ctx->Input(0); + TensorShape elem_shape = ctx->InputShape(1); + + xla::XlaOp ta = xla::GetTupleElement(list, 0); + xla::XlaOp index = xla::GetTupleElement(list, 1); + xla::XlaOp value = ctx->Input(1); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. + auto start_indices = + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + + TensorShape slice_shape = elem_shape; + slice_shape.InsertDim(0, 1LL); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); + + // TODO(phawkins): We don't check the index is in bounds --- there is no + // error mechanism in XLA. + ctx->SetOutput( + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), + index + xla::ConstantR0(b, 1)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp); +}; + +REGISTER_XLA_OP(Name("TensorListPushBack"), TensorListPushBackOp); + +class TensorListPopBackOp : public XlaOpKernel { + public: + explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp state = ctx->Input(0); + + TensorShape shape; + OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape)); + + xla::XlaOp ta = xla::GetTupleElement(state, 0); + xla::XlaOp index = xla::GetTupleElement(state, 1); + + index = index - xla::ConstantR0(b, 1); + + // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. + auto start_indices = + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, shape.dims() - 1}})); + + auto slice_shape = shape.dim_sizes(); + slice_shape[0] = 1LL; + + // TODO(phawkins): We don't check the index is in bounds --- there is no + // error mechanism in XLA. 
+ xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); + // Remove the leading '1' dimension. + std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); + + ctx->SetOutput(0, xla::Tuple(b, {ta, index})); + ctx->SetOutput(1, xla::Reshape(read, value_shape)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp); +}; + +REGISTER_XLA_OP(Name("TensorListPopBack"), TensorListPopBackOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 93d5996b5eaf10221b1d7067e7650b78cd6b8fef..52f2b36e19edd96f491f6706d1872e0d3af2df3b 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -96,7 +96,11 @@ class TileOp : public XlaOpKernel { // operation broadcast semantics. auto broadcasted_zero = xla::Broadcast( XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_shape); - ctx->SetOutput(0, xla::Add(broadcasted_zero, input)); + if (ctx->input_type(0) == DT_BOOL) { + ctx->SetOutput(0, xla::Or(broadcasted_zero, input)); + } else { + ctx->SetOutput(0, xla::Add(broadcasted_zero, input)); + } return; } diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 8597e7f139d8d32b7e08782e70a4ee44d02618f2..1ce3930fd1cd91f8e8dfb765b49be2dc969d1bd7 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -31,6 +31,22 @@ cc_library( ], ) +cc_library( + name = "broadcast", + srcs = ["broadcast.cc"], + hdrs = ["broadcast.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + 
cc_library( name = "cholesky", srcs = ["cholesky.cc"], diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc new file mode 100644 index 0000000000000000000000000000000000000000..3e402ef855cd7c114332d84032bc869232404fc8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +namespace tensorflow { + +xla::StatusOr BroadcastTo(xla::XlaOp input, + absl::Span output_dims) { + xla::XlaBuilder* builder = input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); + absl::Span input_dims = + xla::AsInt64Slice(input_shape.dimensions()); + + if (input_dims == output_dims) { + return input; + } + + if (input_dims.size() > output_dims.size()) { + return errors::InvalidArgument( + "Input shape (", xla::ShapeUtil::HumanString(input_shape), + ") must have rank less than or equal to the output shape [", + absl::StrJoin(output_dims, ","), "]"); + } + + std::vector broadcast_dims; + std::vector broadcast_shape; + auto 
input_it = input_dims.rbegin(); + for (auto output_it = output_dims.rbegin(); output_it != output_dims.rend(); + ++output_it) { + if (input_it != input_dims.rend()) { + if (!(*output_it == 0 && *input_it == 0) && + !(*input_it != 0 && *output_it % *input_it == 0)) { + return errors::InvalidArgument("Invalid shape broadcast from ", + xla::ShapeUtil::HumanString(input_shape), + " to [", absl::StrJoin(output_dims, ","), + "]"); + } + + broadcast_dims.push_back(broadcast_shape.size()); + if (*output_it == *input_it) { + broadcast_shape.push_back(*output_it); + } else if (*output_it != *input_it) { + // Add dimensions [I, O/I], which we will later flatten to just + // [O]. We must do this in two phases since XLA broadcasting does not + // support tiling. + broadcast_shape.push_back(*input_it); + broadcast_shape.push_back(*output_it / *input_it); + } + ++input_it; + } else { + broadcast_shape.push_back(*output_it); + } + } + TF_RET_CHECK(input_it == input_dims.rend()); + + absl::c_reverse(broadcast_dims); + int broadcast_shape_size = broadcast_shape.size(); + for (int64& broadcast_dim : broadcast_dims) { + broadcast_dim = broadcast_shape_size - broadcast_dim - 1; + } + absl::c_reverse(broadcast_shape); + xla::XlaOp output = xla::BroadcastInDim( + input, + xla::ShapeUtil::MakeShape(input_shape.element_type(), broadcast_shape), + broadcast_dims); + if (broadcast_shape != output_dims) { + output = xla::Reshape(output, output_dims); + } + return output; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.h b/tensorflow/compiler/tf2xla/lib/broadcast.h new file mode 100644 index 0000000000000000000000000000000000000000..591e696f06b994a7fdea58bc95ba785f683ce7d1 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/broadcast.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace tensorflow { + +// Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting +// rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling. +xla::StatusOr BroadcastTo(xla::XlaOp input, + absl::Span output_dims); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 38dfde165df47ca78a25a068a901cd1071aa55e2..2b1c2ced925d9fee7392986015a6e716a94d356f 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -38,12 +38,10 @@ xla::StatusOr XlaScatter( combiner, xla::XlaBuilder* builder) { TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer)); - TF_RETURN_IF_ERROR(builder->GetShape(updates).status()); + TF_ASSIGN_OR_RETURN(xla::Shape updates_shape, builder->GetShape(updates)); TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices)); absl::Span indices_dims = xla::AsInt64Slice(indices_shape.dimensions()); - absl::Span buffer_dims = - xla::AsInt64Slice(buffer_shape.dimensions()); // If the indices are N-dimensional, the minor dimension of indices contains // the indices to update. 
Otherwise the indices are all scalars. @@ -81,104 +79,129 @@ xla::StatusOr XlaScatter( } } - // Shape of the non-indexed dimensions of the buffer. - std::vector buffer_shape_post_axes( - buffer_dims.begin() + num_index_dims, buffer_dims.end()); - - // Flatten the major dimensions of indices and updates into a single dimension - // for ease of iteration. - std::vector flat_indices_shape({num_indices}); - if (indices_are_vectors) { - flat_indices_shape.push_back(num_index_dims); + // Example of a 1-D scatter that updates two [3,1] tensors in a tensor of + // shape [3,3]: + // NOTE: ***This case will not be generated by any of the tf.scatter ops.*** + // + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // updates = s32[3,2] parameter(2) + // scatter = s32[3,3] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={0}, + // inserted_window_dims={1}, + // scatter_dims_to_operand_dims={1}, + // index_vector_dim=1 + // + // + // Example of a 1-D scatter that updates two [1,3] tensors in a tensor of + // shape [3,3]: + // + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // updates = s32[2,3] parameter(2) + // scatter = s32[3,3] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={1}, + // inserted_window_dims={0}, + // scatter_dims_to_operand_dims={0}, + // index_vector_dim=1 + // + // + // Example of an N-D scatter updating slices of shape [1,1,2] in a tensor of + // shape [3,3,2] + // + // operand = s32[3,3,2] parameter(0) + // indices = s32[2,2] parameter(1) + // updates = s32[2,2] parameter(2) + // scatter = s32[3,3,2] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={1}, + // inserted_window_dims={0,1}, + // scatter_dims_to_operand_dims={0,1}, + // index_vector_dim=1 + // + // + // Example of a scatter updating slices of shape [] in a tensor of shape [1,1] + // + // operand = s32[1,1] 
parameter(0) + // indices = s32[1] parameter(1) + // updates = s32[1] parameter(2) + // scatter = s32[1,1] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={}, + // inserted_window_dims={0,1}, + // scatter_dims_to_operand_dims={0}, + // index_vector_dim=1 + // Note that updates operand would be broadcasted into [1] in this case. + // + + xla::ScatterDimensionNumbers dim_numbers; + dim_numbers.set_index_vector_dim(indices_are_vectors + ? indices_shape.dimensions_size() - 1 + : indices_shape.dimensions_size()); + + int64 updates_rank = xla::ShapeUtil::Rank(updates_shape); + int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape); + int64 num_window_dims_in_updates = buffer_rank - num_index_dims; + + // If the rank of `updates` is 0 and does not match the expected rank of + // updates, broadcast `updates` to the expected shape of updates. + auto new_updates = updates; + std::vector expected_updates_dims(indices_dims.begin(), + indices_dims.end()); + for (int64 dim = num_index_dims; dim < buffer_rank; ++dim) { + expected_updates_dims.push_back(buffer_shape.dimensions(dim)); + } + int64 expected_updates_rank = expected_updates_dims.size(); + if (updates_rank == 0 && expected_updates_rank != 0) { + new_updates = xla::Broadcast(updates, expected_updates_dims); + TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates)); + updates_rank = xla::ShapeUtil::Rank(updates_shape); } - std::vector flat_updates_shape({num_indices}); - flat_updates_shape.insert(flat_updates_shape.end(), - buffer_shape_post_axes.begin(), - buffer_shape_post_axes.end()); - - // Construct the initial values of the loop-carried Tensors. - auto flat_indices = xla::Reshape(indices, flat_indices_shape); - auto flat_updates = xla::Reshape(updates, flat_updates_shape); - auto init = {flat_indices, flat_updates, buffer}; - - // Constructs the loop body. 
The implementation of scatter is essentially: - // for i in range(num_indices): - // index = dynamic-slice(indices, i) - // update = dynamic-slice(updates, i) - // buffer = dynamic-update-slice(buffer, update, index) - auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, - xla::XlaBuilder* body_builder) { - auto indices = loop_vars[0]; - auto updates = loop_vars[1]; - auto buffer = loop_vars[2]; - - auto zero_index = xla::ConstantLiteral( - body_builder, xla::LiteralUtil::Zero(indices_shape.element_type())); - - // Slice the i-th index from the indices array. - xla::XlaOp index; - auto indices_offset = xla::Reshape(i, {1}); - if (indices_are_vectors) { - indices_offset = xla::Pad(indices_offset, zero_index, - xla::MakeEdgePaddingConfig({{0, 1}})); - - index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims}); - index = xla::Collapse(index, {0, 1}); - } else { - index = xla::DynamicSlice(indices, indices_offset, {1}); + if (updates_rank > 0) { + for (int64 i = (updates_rank - num_window_dims_in_updates); + i < updates_rank; ++i) { + dim_numbers.add_update_window_dims(i); } + } - // Discard updates with negative indices, since some users expect this. - auto index_in_range = xla::ReduceAll( - xla::Le(zero_index, index), xla::ConstantR0(body_builder, true), - xla::CreateScalarAndComputation(xla::PRED, body_builder)); - - // Make the index in bounds to prevent implementation defined behavior. - index = xla::Max(index, zero_index); - index = xla::Pad( - index, zero_index, - xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); - - // Slice the i-th index from the updates array. 
- auto updates_offset = xla::Reshape(i, {1}); - updates_offset = xla::Pad( - updates_offset, zero_index, - xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); - std::vector flat_updates_slice_shape({1}); - flat_updates_slice_shape.insert(flat_updates_slice_shape.end(), - buffer_shape_post_axes.begin(), - buffer_shape_post_axes.end()); - auto update = - xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape); - - // Unflatten the major (iteration) dimensions of the slice to their - // original shape. - std::vector updates_slice_shape(num_index_dims, 1); - updates_slice_shape.insert(updates_slice_shape.end(), - buffer_shape_post_axes.begin(), - buffer_shape_post_axes.end()); - update = xla::Reshape(update, updates_slice_shape); - - // Apply the update to the buffer. If there is a combiner, use it to merge - // the current values with the update. - auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape); + for (int64 i = 0; i < num_index_dims; ++i) { + dim_numbers.add_inserted_window_dims(i); + dim_numbers.add_scatter_dims_to_operand_dims(i); + } + + // Build the combiner computation. + xla::XlaComputation combiner_computation; + { + xla::XlaBuilder cb("scatter-combiner"); + auto xla_scalar_shape = + xla::ShapeUtil::MakeShape(buffer_shape.element_type(), {}); + auto p0 = xla::Parameter(&cb, 0, xla_scalar_shape, "p0"); + auto p1 = xla::Parameter(&cb, 1, xla_scalar_shape, "p1"); if (combiner) { - update = combiner(current_value, update, body_builder); + combiner(p0, p1, &cb); } - // Use the current value instead of the update if the index is out of - // bounds. - update = xla::Select(index_in_range, update, current_value); - // Apply the update. 
- buffer = xla::DynamicUpdateSlice(buffer, update, index); - - return std::vector{indices, updates, buffer}; - }; - - TF_ASSIGN_OR_RETURN(auto outputs, - XlaForEachIndex(num_indices, indices_shape.element_type(), - body_fn, init, "scatter", builder)); - return outputs[2]; + combiner_computation = cb.Build().ConsumeValueOrDie(); + } + + VLOG(3) << "Scatter op:"; + VLOG(3) << " Input: " << xla::ShapeUtil::HumanString(buffer_shape); + VLOG(3) << " Indices: " << xla::ShapeUtil::HumanString(indices_shape); + VLOG(3) << " Updates: " << xla::ShapeUtil::HumanString(updates_shape); + VLOG(3) << " Scatter Dimension Numbers: "; + VLOG(3) << " index_vector_dim: " << dim_numbers.index_vector_dim(); + VLOG(3) << " update_window_dims: [" + << absl::StrJoin(dim_numbers.update_window_dims(), ",") << "]"; + VLOG(3) << " inserted_window_dims: [" + << absl::StrJoin(dim_numbers.inserted_window_dims(), ",") << "]"; + VLOG(3) << " scatter_dims_to_operand_dims: [" + << absl::StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ",") + << "]"; + + return xla::Scatter(buffer, indices, new_updates, combiner_computation, + dim_numbers); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 13a5f1b850a612bddeeac39bef431c19925351ca..4cf478c4b9b4316f1cf43f45d1bf90afa648fb11 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -34,7 +34,11 @@ namespace tensorflow { // Otherwise, `indices_are_vectors`, then indices are multidimensional and the // minor dimension of `indices` represents a vector of indices. // -// If any indices are negative, the corresponding update is discarded. +// If `updates` is a scalar, then it will be broadcasted into the expected shape +// of updates. +// +// If any part of the update region is out-of-bounds, the corresponding update +// is discarded. 
// // If a `combiner` is provided, updates are combined with the existing values in // the buffer using the combiner function. Otherwise, the updates replace the diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 733eeed3c661c9ed683f0fb7fd90f7f997b8dc2b..bd2c0a5ee88869ba60701c0a7ace05857452eed9 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -283,6 +283,8 @@ REGISTER_OP("XlaReduceWindow") .Input("init_value: T") .Input("window_dimensions: Tindices") .Input("window_strides: Tindices") + .Input("base_dilations: Tindices") + .Input("window_dilations: Tindices") .Input("padding: Tindices") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") @@ -354,12 +356,33 @@ Wraps the XLA Sort operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#sort . -Sorts a tensor. Currently only rank 1 sorts in ascending order are supported. +Sorts a tensor. Currently only sorts in ascending order are supported. input: A `Tensor` of type T. output: A `Tensor` of type T. )doc"); +REGISTER_OP("XlaKeyValueSort") + .Input("keys: K") + .Input("values: V") + .Output("sorted_keys: K") + .Output("sorted_values: V") + .Attr("K: realnumbertype") + .Attr("V: type") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA Sort operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#sort +. + +Sorts a tensor. Currently only sorts in ascending order are supported. + +keys: A `Tensor` of type K. +values: A `Tensor` of type V. +sorted_keys: A `Tensor` of type K. +sorted_values: A `Tensor` of type V. +)doc"); + // TODO(b/37549631) setting the While Op to always be stateful is too // conservative. 
REGISTER_OP("XlaWhile") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 27dd18a9bbd5aceece41aaf61eb185acb537b3b6..5e86b5d8ec0a2690f004bc67decea09185d9cbb6 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -212,9 +212,9 @@ bitcast_convert_type = array_ops.bitcast def broadcast(x, dims, name=None): x = ops.convert_to_tensor(x) - shape = array_ops.concat( - [constant_op.constant(dims), - array_ops.shape(x)], axis=0) + shape = array_ops.concat([constant_op.constant(dims), + array_ops.shape(x)], + axis=0) return array_ops.broadcast_to(x, shape, name=name) @@ -320,6 +320,8 @@ def reduce_window(operand, reducer, window_dimensions, window_strides=None, + base_dilations=None, + window_dilations=None, padding=None, name=None): """Wraps the XLA ReduceWindow operator. @@ -332,22 +334,27 @@ def reduce_window(operand, init: a scalar tensor representing the initial value for the reduction reducer: a reduction function that combines a pair of scalars. window_dimensions: shape of the window, as a list of integers - window_strides: inter-window strides, as a list of integers. Optional; - if omitted, defaults to strides of 1. + window_strides: inter-window strides, as a list of integers. Optional; if + omitted, defaults to strides of 1. padding: padding to apply to 'operand'. List of (low, high) pairs of integers that specify the padding to apply before and after each dimension. Optional; if omitted, defaults to no padding. name: the operator name, or None. + Returns: A tensor that represents the output of the reduce_window operator. 
""" window_strides = window_strides or [1] * len(window_dimensions) + base_dilations = base_dilations or [1] * len(window_dimensions) + window_dilations = window_dilations or [1] * len(window_dimensions) padding = padding or [(0, 0)] * len(window_dimensions) return gen_xla_ops.xla_reduce_window( input=operand, init_value=init, window_dimensions=window_dimensions, window_strides=window_strides, + base_dilations=base_dilations, + window_dilations=window_dilations, padding=padding, computation=reducer, name=name) @@ -377,4 +384,5 @@ def slice(x, start_dims, limit_dims, strides): sort = gen_xla_ops.xla_sort +key_value_sort = gen_xla_ops.xla_key_value_sort while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 20f2ce2919701731ef6e90d368b67545af95e8f9..72b240996fb4d9dcb5f5dfd919da618cbae08c16 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "absl/algorithm/container.h" -#include "tensorflow/core/lib/gtl/flatmap.h" +#include "absl/container/flat_hash_map.h" namespace tensorflow { /*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString( @@ -30,9 +30,9 @@ namespace tensorflow { } } -static gtl::FlatMap* +static absl::flat_hash_map* CreateResourceOpInfoMap() { - auto* result = new gtl::FlatMap; + auto* result = new absl::flat_hash_map; auto add = [&](absl::string_view op, XlaResourceOpKind op_kind, XlaResourceKind resource_kind) { @@ -103,15 +103,15 @@ CreateResourceOpInfoMap() { return result; } -static const gtl::FlatMap& +static const absl::flat_hash_map& GetStaticResourceOpInfoMap() { - static gtl::FlatMap* op_info_map = - CreateResourceOpInfoMap(); + static absl::flat_hash_map* + op_info_map = CreateResourceOpInfoMap(); return *op_info_map; } const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) { - const gtl::FlatMap& op_infos = + const absl::flat_hash_map& op_infos = GetStaticResourceOpInfoMap(); auto it = op_infos.find(op); return it == op_infos.end() ? nullptr : &it->second; diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc index a85ef040a7b65c2f6e405c3444eaef3019137b4b..956f597301d28d781a9df7ab2086ed79d4c8bf9d 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -33,7 +34,7 @@ bool HasResourceInputOrOutput(const OpDef& op_def) { } TEST(ResourceOperationTableTest, HaveAllResourceOps) { - gtl::FlatMap known_resource_ops; + absl::flat_hash_map known_resource_ops; for (absl::string_view known_resource_op : resource_op_table_internal::GetKnownResourceOps()) { ASSERT_TRUE( diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index 6cd7b24592f30d7202b985f3dfd082ea2d85e344..b233e6b2c28e1968bb74901fc684e808ae45ab60 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/side_effect_util.h" +#include "absl/strings/numbers.h" #include "tensorflow/core/graph/algorithm.h" namespace tensorflow { @@ -64,4 +65,28 @@ bool HasSideEffectingNodes(const Graph& g) { return false; } +Status ParseHostComputeCoreList(absl::Span list_from_attr, + std::map* host_compute_core) { + for (const auto& hc_core : list_from_attr) { + std::vector parts = str_util::Split(hc_core, ":"); + if (parts.size() != 2) { + return errors::InvalidArgument( + "Malformed host_compute_core entry ", hc_core, + " should be :."); + } + int core; + if (!absl::numbers_internal::safe_strto32_base(parts[1], &core, 10)) { + return errors::InvalidArgument("Malformed host_compute_core entry ", + hc_core, + " part after ':' should be an integer."); + } + if (host_compute_core->find(parts[0]) != host_compute_core->end()) { + return errors::InvalidArgument( + "Duplicate host_compute_core entry for cluster ", parts[0]); + } + (*host_compute_core)[parts[0]] = core; + } + return Status::OK(); +} + } // 
namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index ad07624729f0b0d2443b2fc43d32dfa3377ce115..f22ddb2f58e1fa5c10ca0fdb956d9136942388b7 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -42,6 +42,12 @@ std::set CalculateTokenInputsForOutputToken(const Graph& g); // Returns whether a graph contains side-effecting nodes. bool HasSideEffectingNodes(const Graph& g); +// Parse the mapping from outside_compilation_subgraph name to core number, +// which is specified in an attr as a list of strings +// :. +Status ParseHostComputeCoreList(absl::Span list_from_attr, + std::map* host_compute_core); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index f31bfb45a2f4db270446eb59259969dc0ab63a8e..3c6c9a91b6d2fb47f6dee1c347e9b852f1eea3ec 100644 --- a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -40,12 +40,4 @@ Status InstantiateFunctionForTest(const string& name, return Status::OK(); } -std::unordered_map BuildNodeIndex(const Graph& graph) { - std::unordered_map index; - for (Node* node : graph.nodes()) { - index[node->name()] = node; - } - return index; -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h index 350a868568531c0d073e0cf600327d1ff9d62e3a..4ffc94ae3bc7c930720cd625a7856443c77be666 100644 --- a/tensorflow/compiler/tf2xla/test_util.h +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -44,9 +44,6 @@ Status InstantiateFunctionForTest(const string& name, const FunctionLibraryDefinition& library, InstantiationResultForTest* result); -// Builds a map from node name to Node* for `graph`. 
-std::unordered_map BuildNodeIndex(const Graph& graph); - } // namespace tensorflow // Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 01dd3ba10fec85e6b1d411fbd32fbf9c58b5fe11..cc83db0562dd4ef1ae7b7a718a8f2e407acbfa1e 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -330,8 +330,8 @@ uint32 GetXLARandomSeed() { // TODO(b/77601805): add tests for associated function related stuff. bool HasAssociatedFunction(const NodeDef& node_def, - FunctionLibraryRuntime* flr) { - if (flr->GetFunctionLibraryDefinition()->Contains(node_def.op())) { + const FunctionLibraryDefinition* fld) { + if (fld->Contains(node_def.op())) { return true; } @@ -351,10 +351,10 @@ bool HasAssociatedFunction(const NodeDef& node_def, } std::vector GetAssociatedFunctions( - const Node& node, FunctionLibraryRuntime* flr) { + const Node& node, const FunctionLibraryDefinition* fld) { std::vector results; const string& op = node.type_string(); - if (flr->GetFunctionLibraryDefinition()->Contains(op)) { + if (fld->Contains(op)) { // This is a function call node. 
AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs)); @@ -441,4 +441,28 @@ Status RewriteAssociatedFunction( return Status::OK(); } +Status CachedFunctionHandles::GetOrInstantiate( + const string& func_name, AttrSlice attrs, + FunctionLibraryRuntime::Handle* handle) { + string canonicalized_name = Canonicalize(func_name, attrs); + auto iter = handles_.find(canonicalized_name); + if (iter != handles_.end()) { + *handle = iter->second; + return Status::OK(); + } + + TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle)); + handles_[canonicalized_name] = *handle; + return Status::OK(); +} + +Status CachedFunctionHandles::ReleaseAllHandles() { + Status result; + for (auto iter : handles_) { + result.Update(flr_->ReleaseHandle(iter.second)); + } + handles_.clear(); + return result; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 53eab8b63e2fc8aa3dfb0bacfe065897ca775bd0..b974b998229982afc9168dcaf0799cfddd965a04 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -120,7 +120,7 @@ class AssociatedFunctionInfo { // Returns if the NodeDef has associated function. bool HasAssociatedFunction(const NodeDef& node_def, - FunctionLibraryRuntime* flr); + const FunctionLibraryDefinition* fld); // Gets functions associated with the node. Current cases: // 1. For function call node, its function name; @@ -128,7 +128,7 @@ bool HasAssociatedFunction(const NodeDef& node_def, // and returned attrs will be this node's attributes; // 3. For nodes like XlaWhile/XlaIf, all their function attributes. std::vector GetAssociatedFunctions( - const Node& node, FunctionLibraryRuntime* flr); + const Node& node, const FunctionLibraryDefinition* fld); // Changes associated functions for the node. Current cases: // 1. 
For function call node, creates a new node with the new function name and @@ -144,6 +144,30 @@ Status RewriteAssociatedFunction( // Attribute to mark nodes to be executed on host. extern const char kXlaOutsideCompilationAttrName[]; +// Class to act as cache for FunctionLibraryRuntime::Handle objects. +class CachedFunctionHandles { + public: + CachedFunctionHandles(FunctionLibraryRuntime* flr) : flr_(flr) {} + + // Populates `handle` for requested function and attributes. If we have + // instantiated the function with the same attributes before, `handle` will be + // cached handle; otherwise instantiate the function and populate `handle`. + Status GetOrInstantiate(const string& func_name, AttrSlice attrs, + FunctionLibraryRuntime::Handle* handle); + + // Releases all handles in the cache. Returns first non-OK status if any; + // returns OK otherwise. + Status ReleaseAllHandles(); + + ~CachedFunctionHandles() { ReleaseAllHandles().IgnoreError(); } + + private: + FunctionLibraryRuntime* flr_; + std::map handles_; + + TF_DISALLOW_COPY_AND_ASSIGN(CachedFunctionHandles); +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 68441b3d4790b17bd06accff3fcdc8ccee79bbb7..202e929315cacd4d6cdfc69d50639d8a427ec6c2 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -23,11 +23,15 @@ limitations under the License. 
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { @@ -255,5 +259,75 @@ TEST(SetNodeShardingFromNeighbors, Basic) { EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0)); } +REGISTER_OP("One") + .Output("y: T") + .Attr("T: {float, double, int32, int64}") + .Doc(R"doc( +Returns a tensor with a single element (1) of type T. + +y: A scalar in type T. + +)doc"); + +// Tests that CachedFunctionHandles class works. +TEST(CachedFunctionHandles, Basic) { + FunctionDef func = FunctionDefHelper::Define( + // Name + "TestFunc", + // Args + {}, + // Return values + {"y:T"}, + // Attr def + {"T:{float, double, int32, int64}"}, + // Nodes + { + {{"y"}, "One", {}, {{"T", "$T"}}}, + }); + FunctionDefLibrary proto; + *proto.add_function() = func; + FunctionLibraryDefinition fld(OpRegistry::Global(), proto); + std::unique_ptr pflr( + new ProcessFunctionLibraryRuntime( + /*device_mgr=*/nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, &fld, + OptimizerOptions())); + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + CachedFunctionHandles cached_function_handles(flr); + + // Tests that GetOrInstantiate() works. 
+ FunctionLibraryRuntime::Handle first_handle; + AttrValue attr; + attr.set_type(DT_FLOAT); + AttrValueMap attrs; + attrs["T"] = attr; + TF_ASSERT_OK(cached_function_handles.GetOrInstantiate( + "TestFunc", AttrSlice(&attrs), &first_handle)); + + // Tests that we can get FunctionBody. + const FunctionBody* body = flr->GetFunctionBody(first_handle); + EXPECT_NE(body, nullptr); + + // Tests that GetOrInstantiate() returns cached handle when called with same + // function name and attributes. + FunctionLibraryRuntime::Handle second_handle; + TF_ASSERT_OK(cached_function_handles.GetOrInstantiate( + "TestFunc", AttrSlice(&attrs), &second_handle)); + EXPECT_EQ(first_handle, second_handle); + + // Tests that GetOrInstantiate() returns new handle when called with same + // function name but different attributes. + attr.set_type(DT_INT32); + attrs["T"] = attr; + FunctionLibraryRuntime::Handle third_handle; + TF_ASSERT_OK(cached_function_handles.GetOrInstantiate( + "TestFunc", AttrSlice(&attrs), &third_handle)); + EXPECT_NE(first_handle, third_handle); + + // Tests that ReleaseAllHandles() works. + TF_EXPECT_OK(cached_function_handles.ReleaseAllHandles()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index d5094e8ec5ed95b8cdbad63762a7fbc718ba5f30..b2c57e88803e0661a9a514f844dff97ff9edf2ea 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -194,6 +194,17 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, std::unique_ptr graph = GetGraph(fbody); + // Clear the "_kernel" attribute if it is set to "host". This is used to + // indicate that a computation should happen on the host instead of the + // accelerator, but doesn't make sense in XLA. 
+ const char* const kKernelAttr = "_kernel"; + for (Node* n : graph->nodes()) { + string value; + if (GetNodeAttrSimple(n->attrs(), kKernelAttr, &value) && value == "host") { + n->ClearAttr(kKernelAttr); + } + } + // _Arg and _Retval nodes don't exist in the stored subgraph for the function; // they are added by the function body looked up. Therefore, they don't have // core assignments here. diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index f247570d72c0287a33695de3d778cce2a2418921..2095a6b8099f48a867ec2c7c7d6e84d8f2426dce 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -138,7 +138,8 @@ Status XlaContext::CreateResource( const std::set& tensor_array_gradients, XlaResource** resource) { resources_.emplace_back( new XlaResource(kind, arg_num, std::move(name), type, std::move(shape), - handle, tensor_array_size, tensor_array_gradients)); + handle, tensor_array_size, tensor_array_gradients, + /*tensor_array_multiple_writes_aggregate=*/false)); *resource = resources_.back().get(); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 2a9eaeee146bf6d792e010df7e041f9986b2c77e..dd3498ef7aa242d3ad946cae5f60bc2c8853a342 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -455,23 +455,43 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, return Status::OK(); } +Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape, + Tensor** output) { + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. 
+ if (expected_output_dtype(index) == DT_VARIANT) { + // tensor_data() is not supported for variant Tensor (i.e., + // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the + // XlaExpression inside the Tensor's tensor_data() does not work for + // variant. Instead construct a uint8 tensor and store the expression in its + // value. + // TODO(jpienaar): This should be refactored to stop masquerading + // XlaExpressions as Tensors. + *output = new Tensor(); + TensorShape tensor_shape; + TF_RETURN_IF_ERROR( + context_->allocate_temp(DT_UINT8, tensor_shape, *output)); + context_->set_output(index, **output); + } else { + TensorShape tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape)); + TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output)); + } + return Status::OK(); +} + void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { // Makes the host Tensor that will refer to the expression. Tensor* output = nullptr; - auto shape = builder()->GetShape(handle); - if (!shape.ok()) { - SetStatus(shape.status()); + auto shape_or = builder()->GetShape(handle); + if (!shape_or.ok()) { + SetStatus(shape_or.status()); return; } - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. - TensorShape tensor_shape; - OP_REQUIRES_OK(context_, - XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape)); OP_REQUIRES_OK(context_, - context_->allocate_output(index, tensor_shape, &output)); + allocate_output(index, shape_or.ValueOrDie(), &output)); // The expression is stored in the tensor's data buffer. Fill in the // fields now. 
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index a3a0d10cc06cd4afceec728b7dbe287389099b9d..aa00a454968ad29495e34dc080e55b62bb0b5f7b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -255,6 +255,11 @@ class XlaOpKernelContext { // Returns the tensor of input `name`. const Tensor& GetInputTensorByName(absl::string_view name); + // Wraps OpKernelContext's allocate_output method while providing special + // behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the + // type to allow mapping for variant to more generic types. + Status allocate_output(int index, const xla::Shape& shape, Tensor** output); + OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 56c2e01055665954b99ea635e56666fbd8b96026..63b09c8f02a60e91576544d13227d29f56d3e88c 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -29,7 +29,8 @@ namespace tensorflow { XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, int64 tensor_array_size, - const std::set& tensor_array_gradients) + const std::set& tensor_array_gradients, + bool tensor_array_multiple_writes_aggregate) : kind_(kind), arg_num_(arg_num), name_(std::move(name)), @@ -37,14 +38,17 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, shape_(std::move(shape)), value_(initial_value), initial_value_(initial_value), - tensor_array_size_(tensor_array_size) { + tensor_array_size_(tensor_array_size), + tensor_array_multiple_writes_aggregate_( + tensor_array_multiple_writes_aggregate) { CHECK(kind_ != kInvalid); for (const string& gradient : tensor_array_gradients) { tensor_array_gradients_[gradient].reset(new XlaResource( /*kind=*/kTensorArray, /*arg_num=*/-1, 
/*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_, - xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{})); + xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{}, + /*tensor_array_multiple_writes_aggregate=*/true)); } } @@ -137,7 +141,8 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_, gradient_value, tensor_array_size_, - /*tensor_array_gradients=*/{})); + /*tensor_array_gradients=*/{}, + /*tensor_array_multiple_writes_aggregate=*/true)); } *gradient_out = gradient.get(); return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 2438490be13809b9f3571a362900b44cb838e76b..aa9ce1b171f11ea0de4db0123098729c1c97f93a 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -39,7 +39,8 @@ class XlaResource { XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, int64 tensor_array_size, - const std::set& tensor_array_gradients); + const std::set& tensor_array_gradients, + bool tensor_array_multiple_writes_aggregate); XlaResource(const XlaResource&) = delete; XlaResource(XlaResource&&) = delete; @@ -113,6 +114,8 @@ class XlaResource { const xla::XlaOp& pack, xla::XlaBuilder* builder); // TensorArray and Stack specific fields + // TODO(phawkins): refactor this code to use subclasses, rather than putting + // kind-specific fields in XlaResource. // 'tensor_array_size' stores the expected size of the TensorArray or Stack. 
// We need to store this since sometimes TensorArrays must be initialized @@ -121,6 +124,10 @@ class XlaResource { int64 tensor_array_size() const { return tensor_array_size_; } void set_tensor_array_size(int64 size) { tensor_array_size_ = size; } + bool tensor_array_multiple_writes_aggregate() const { + return tensor_array_multiple_writes_aggregate_; + } + // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes // to an XlaResource containing the gradient TensorArrays. We store a pointer // here since there should only be one gradient TensorArray per 'source' @@ -143,6 +150,7 @@ class XlaResource { xla::XlaOp initial_value_; int64 tensor_array_size_ = -1; + bool tensor_array_multiple_writes_aggregate_ = false; std::map> tensor_array_gradients_; }; diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index f825f67b447514a416f3a49ac8aad9dcf505f5a7..dc097f3696e22d75d7dc72ec4877a9c8b5dda059 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -220,6 +220,8 @@ cc_library( "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index a904be259a3870a679b2c4699ec01e2a11b1ce46..0475fd9c94f6e390b5169cfe2cbba8eae28ddc18 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -29,7 +29,7 @@ XlaOp TopK(XlaOp input, int64 k) { auto input_dims = input_shape.dimensions(); std::vector broadcast_dims(input_dims.begin(), input_dims.end() - 1); XlaOp broadcast_s32 = Broadcast(iota_s32, broadcast_dims); - XlaOp sort_result = Sort(Neg(input), 
broadcast_s32); + XlaOp sort_result = Sort(Neg(input), {broadcast_s32}); std::vector start_indices(input_shape.dimensions_size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); limit_indices[last_dim] = k; diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index ff0ec76a7f9b62fce0f14beae688cb0dd74847a1..a44681f586278bf03f3fb2b8c812936cbf3ad47b 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -93,9 +93,9 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, std::vector> MakeFakeArgumentsOrDie( const XlaComputation& computation, Client* client) { - CHECK(computation.proto().has_program_shape()) + CHECK(computation.proto().has_host_program_shape()) << "Computation should have progran shape."; - auto program_shape = computation.proto().program_shape(); + auto program_shape = computation.proto().host_program_shape(); std::vector> results; for (const Shape& shape : program_shape.parameters()) { diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 5277de6a85026573691b8337524c86142139cc4a..7d081b27222bd31ddbe7c64b4dea8a4d5a371acb 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -33,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/mutex.h" namespace xla { @@ -208,6 +208,9 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, case HloOpcode::kWhile: // TODO(b/32495713): We aren't checking the condition and body // computations themselves. + case HloOpcode::kScatter: + // TODO(b/32495713): We aren't checking the embedded computation in + // Scatter. case HloOpcode::kSend: case HloOpcode::kRecv: case HloOpcode::kParameter: @@ -275,7 +278,7 @@ StatusOr XlaBuilder::Build(int64 root_id) { module->set_id(entry.id()); module->set_entry_computation_name(entry.name()); module->set_entry_computation_id(entry.id()); - *module->mutable_program_shape() = entry.program_shape(); + *module->mutable_host_program_shape() = entry.program_shape(); for (auto& e : embedded_) { module->add_computations()->Swap(&e.second); } @@ -1276,9 +1279,10 @@ XlaOp XlaBuilder::AfterAll(absl::Span tokens) { }); } -XlaOp XlaBuilder::CustomCall(const string& call_target_name, - absl::Span operands, - const Shape& shape, const string& opaque) { +XlaOp XlaBuilder::CustomCall( + const string& call_target_name, absl::Span operands, + const Shape& shape, const string& opaque, + absl::optional> operand_shapes_with_layout) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (absl::StartsWith(call_target_name, "$")) { @@ -1290,6 +1294,31 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, *instr.mutable_shape() = shape; instr.set_custom_call_target(call_target_name); instr.set_custom_call_opaque(opaque); + if (operand_shapes_with_layout.has_value()) { + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument( + "Result shape must have layout for custom call with constrained " + "layout."); + } + if (operands.size() != 
operand_shapes_with_layout->size()) { + return InvalidArgument( + "Must specify a shape with layout for each operand for custom call " + "with constrained layout; given %d shapes, expected %d", + operand_shapes_with_layout->size(), operands.size()); + } + instr.set_constrain_layout(true); + int64 operand_num = 0; + for (const Shape& operand_shape : *operand_shapes_with_layout) { + if (!LayoutUtil::HasLayout(operand_shape)) { + return InvalidArgument( + "No layout specified for operand %d for custom call with " + "constrained layout.", + operand_num); + } + *instr.add_operand_shapes_with_layout() = operand_shape; + ++operand_num; + } + } return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); }); } @@ -1465,18 +1494,17 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } -XlaOp XlaBuilder::Sort(XlaOp keys, absl::optional values, +XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span values, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); operand_shape_ptrs.push_back(&keys_shape); - Shape values_shape; - if (values.has_value()) { - TF_ASSIGN_OR_RETURN(values_shape, GetShape(*values)); - operand_shape_ptrs.push_back(&values_shape); - } + TF_ASSIGN_OR_RETURN(std::vector values_shapes, + GetOperandShapes(values)); + absl::c_transform(values_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferVariadicOpShape( HloOpcode::kSort, operand_shape_ptrs)); @@ -1485,10 +1513,9 @@ XlaOp XlaBuilder::Sort(XlaOp keys, absl::optional values, dimension = ShapeUtil::Rank(keys_shape) - 1; } instr.add_dimensions(dimension); - return values.has_value() - ? 
AddInstruction(std::move(instr), HloOpcode::kSort, - {keys, *values}) - : AddInstruction(std::move(instr), HloOpcode::kSort, {keys}); + std::vector operands{keys}; + operands.insert(operands.end(), values.begin(), values.end()); + return AddInstruction(std::move(instr), HloOpcode::kSort, operands); }); } @@ -1786,9 +1813,9 @@ XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value, std::vector> padding_values = MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions, window_strides, padding); - return ReduceWindowWithGeneralPadding(operand, init_value, computation, - window_dimensions, window_strides, - padding_values); + return ReduceWindowWithGeneralPadding( + operand, init_value, computation, window_dimensions, window_strides, + /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values); }); } @@ -1797,6 +1824,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1807,7 +1836,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( computation.GetProgramShape()); TF_ASSIGN_OR_RETURN(*instr.mutable_window(), MakeWindow(window_dimensions, window_strides, padding, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); + /*lhs_dilation=*/base_dilations, + /*rhs_dilation=*/window_dilations)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferReduceWindowShape(operand_shape, init_shape, @@ -2290,7 +2320,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( // also a valid dependency order). The related ops will be added to the // subgraph in the same order. std::set related_ops; - tensorflow::gtl::FlatSet related_calls; // Related computations. + absl::flat_hash_set related_calls; // Related computations. 
std::queue worklist; worklist.push(root->id()); related_ops.insert(root->id()); @@ -2327,7 +2357,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( module->set_id(entry.id()); module->set_entry_computation_name(entry.name()); module->set_entry_computation_id(entry.id()); - *module->mutable_program_shape() = *program_shape; + *module->mutable_host_program_shape() = *program_shape; for (auto& e : embedded_) { if (related_calls.find(e.second.id()) != related_calls.end()) { *module->add_computations() = e.second; @@ -2684,7 +2714,16 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque) { - return builder->CustomCall(call_target_name, operands, shape, opaque); + return builder->CustomCall(call_target_name, operands, shape, opaque, + /*operand_shapes_with_layout=*/absl::nullopt); +} + +XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape, + absl::Span operand_shapes_with_layout, + const string& opaque) { + return builder->CustomCall(call_target_name, operands, shape, opaque, + operand_shapes_with_layout); } XlaOp Complex(const XlaOp& real, const XlaOp& imag, @@ -2797,10 +2836,12 @@ XlaOp ReduceWindowWithGeneralPadding( const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding) { return operand.builder()->ReduceWindowWithGeneralPadding( operand, init_value, computation, window_dimensions, window_strides, - padding); + base_dilations, window_dilations, padding); } XlaOp CrossReplicaSum(const XlaOp& operand, @@ -2911,8 +2952,8 @@ XlaOp Rev(const XlaOp& operand, absl::Span dimensions) { return operand.builder()->Rev(operand, dimensions); } -XlaOp Sort(XlaOp keys, absl::optional values, int64 dimension) { - return keys.builder()->Sort(keys, 
std::move(values), dimension); +XlaOp Sort(const XlaOp& keys, absl::Span values, int64 dimension) { + return keys.builder()->Sort(keys, values, dimension); } XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 1da6ddd318505eccb9c3f0d007adb785fce8ad08..5747661c34b411bbf22575f9c1d9fe09aa32911f 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/padding.h" @@ -34,8 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stacktrace.h" #include "tensorflow/core/platform/types.h" @@ -577,9 +577,10 @@ class XlaBuilder { absl::Span operands); // Enqueues a custom call instruction onto the computation. - XlaOp CustomCall(const string& call_target_name, - absl::Span operands, const Shape& shape, - const string& opaque); + XlaOp CustomCall( + const string& call_target_name, absl::Span operands, + const Shape& shape_with_layout, const string& opaque, + absl::optional> operand_shapes_with_layout); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. 
The shapes of the operands have to match unless one @@ -671,6 +672,8 @@ class XlaBuilder { const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding); // Returns the sum of the operand value within each subgroup of replicas. All @@ -696,7 +699,7 @@ class XlaBuilder { // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will // not be applied cross modules. // - // TODO(b/79737069): Rename this to AllReduce when it's ready to use. + // TODO(b/117564385): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, absl::Span replica_groups = {}, @@ -831,12 +834,12 @@ class XlaBuilder { // the last dimension is chosen by default. // // If both keys and values are provided: - // * The keys and the values must tensors with the same dimensions. The + // * The keys and all values must be tensors with the same dimensions. The // element types of the tensors may be different. // * The result is a tuple that consists of a sorted tensor of keys (along the - // provided dimension, as above) as the first element, and a tensor with their - // corresponding values as the second element. - XlaOp Sort(XlaOp keys, absl::optional values = absl::nullopt, + // provided dimension, as above) as the first element, and tensors with their + // corresponding values as the other elements. + XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); // Enqueues a clamp instruction onto the computation. @@ -1027,7 +1030,7 @@ class XlaBuilder { // A map from XlaOp::Handle to the index in the instructions_ vector where the // instruction is held. - tensorflow::gtl::FlatMap handle_to_index_; + absl::flat_hash_map handle_to_index_; // The embedded computations used by this computation. 
Each computation was // the entry computation of some XlaComputation, the key is the unique id of @@ -1035,7 +1038,7 @@ class XlaBuilder { std::map embedded_; // The unique parameter numbers. - tensorflow::gtl::FlatSet parameter_numbers_; + absl::flat_hash_set parameter_numbers_; // The metadata to attach to each op. This is structured as a "modal"-like // operation, in order to simplify client code (and not sprinkle this metadata @@ -1195,6 +1198,10 @@ class XlaBuilder { friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque); + friend XlaOp CustomCallWithLayout( + XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, const string& opaque); friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, absl::Span broadcast_dimensions); friend XlaOp Conj(const XlaOp& operand); @@ -1245,6 +1252,8 @@ class XlaBuilder { const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding); friend XlaOp CrossReplicaSum(const XlaOp& operand, absl::Span replica_groups); @@ -1302,7 +1311,8 @@ class XlaBuilder { friend XlaOp Transpose(const XlaOp& operand, absl::Span permutation); friend XlaOp Rev(const XlaOp& operand, absl::Span dimensions); - friend XlaOp Sort(XlaOp keys, absl::optional values, int64 dimension); + friend XlaOp Sort(const XlaOp& keys, absl::Span values, + int64 dimension); friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); friend XlaOp Map(XlaBuilder* builder, absl::Span operands, const XlaComputation& computation, @@ -1728,6 +1738,17 @@ XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque = ""); +// Overload which constructs a custom call with fixed layouts. 
The operands will +// have the layouts specified by |operand_shapes_with_layout| when provided to +// external code, and the external code is expected to produce a result with the +// layout specified by |shape_with_layout|. All shapes in |shape_with_layout| +// and |operand_shapes_with_layout| must have layouts. +XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, + absl::Span operands, + const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, + const string& opaque = ""); + // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one // of the operands is a scalar, or an explicit broadcast dimension is given @@ -1818,6 +1839,8 @@ XlaOp ReduceWindowWithGeneralPadding( const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding); // Returns the sum of the operand value within each subgroup of replicas. All @@ -1842,7 +1865,7 @@ XlaOp CrossReplicaSum(const XlaOp& operand, // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be // applied cross modules. // -// TODO(b/79737069): Rename this to AllReduce when it's ready to use. +// TODO(b/117564385): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, absl::Span replica_groups = {}, @@ -1980,12 +2003,12 @@ XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // the last dimension is chosen by default. // // If both keys and values are provided: -// * The keys and the values must tensors with the same dimensions. The +// * The keys and all values must be tensors with the same dimensions. The // element types of the tensors may be different. 
// * The result is a tuple that consists of a sorted tensor of keys (along the -// provided dimension, as above) as the first element, and a tensor with their -// corresponding values as the second element. -XlaOp Sort(XlaOp keys, absl::optional values = absl::nullopt, +// provided dimension, as above) as the first element, and tensors with their +// corresponding values as the other elements. +XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); // Enqueues a clamp instruction onto the computation. diff --git a/tensorflow/compiler/xla/client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc index 22c9e83bb2ae9e3e205bdd480b64c703e31c6ffd..c9870b65b91c1ebd7d44143faf215a2d5c2a2fc5 100644 --- a/tensorflow/compiler/xla/client/xla_computation.cc +++ b/tensorflow/compiler/xla/client/xla_computation.cc @@ -24,8 +24,8 @@ limitations under the License. namespace xla { StatusOr XlaComputation::GetProgramShape() const { - TF_RET_CHECK(proto_.has_program_shape()); - return proto_.program_shape(); + TF_RET_CHECK(proto_.has_host_program_shape()); + return proto_.host_program_shape(); } StatusOr> XlaComputation::Snapshot() const { diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index d310335618ded7b581e6ed632223218585bb791f..66af644cf78f3cc3ebecfaba67cf7d023b0360d5 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -65,6 +65,12 @@ void SetDefaultLayoutToContainer( return layout; } +/* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) { + std::vector layout(rank); + std::iota(layout.rbegin(), layout.rend(), static_cast(0)); + return MakeLayout(layout); +} + /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( absl::Span major_to_minor) { Layout layout; @@ -156,18 +162,23 @@ Layout CreateDefaultLayoutForRank(int64 rank) { LayoutUtil::SetToDefaultLayout(program_shape->mutable_result()); } -/* static */ Status 
LayoutUtil::ValidateLayoutInShape(const Shape& shape) { +/* static */ Status LayoutUtil::ValidateLayoutInShape( + const Shape& shape, bool allow_missing_layouts) { if (ShapeUtil::IsTuple(shape)) { // Tuple shape. if (shape.has_layout()) { return InvalidArgument("tuple should not have a layout field"); } for (auto& element_shape : shape.tuple_shapes()) { - TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); + TF_RETURN_IF_ERROR( + ValidateLayoutInShape(element_shape, allow_missing_layouts)); } return Status::OK(); } else if (ShapeUtil::IsArray(shape)) { if (!shape.has_layout()) { + if (allow_missing_layouts) { + return Status::OK(); + } return InvalidArgument("shape %s does not have a layout", ShapeUtil::HumanString(shape)); } @@ -199,10 +210,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return Status::OK(); } - if (layout.format() == INVALID_FORMAT) { + if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) { return InvalidArgument( - "Layout does not have a valid format: layout {%s}, shape {%s}", - layout.ShortDebugString(), shape.ShortDebugString()); + "Layout has an invalid format (%d) in layout {%s}, shape {%s}", + layout.format(), layout.ShortDebugString(), shape.ShortDebugString()); } if (layout.format() == DENSE) { diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index b78883c2d870043032306637730c4666665125a8..97806d7e3311141920551a17d56d8ae9a1fe4af9 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -40,6 +40,10 @@ class LayoutUtil { static Layout MakeLayoutFromMajorToMinor( absl::Span major_to_minor); + // Returns a layout with descending ((i.e. {n, n-1, ..., 0}) minor-to-major + // dimensions. + static Layout MakeDescendingLayout(int64 rank); + // Creates a sparse layout with the given maximum number of elements. (This is // a convenience function for protobuf construction.) 
static Layout MakeSparseLayout(int64 max_sparse_elements); @@ -64,8 +68,11 @@ class LayoutUtil { // default. static void SetToDefaultLayout(ProgramShape* program_shape); - // Validates that the layout within the given shape is correct. - static Status ValidateLayoutInShape(const Shape& shape); + // Validates that the layout within the given shape is correct. The check + // is performed for all subshapes as well. If missing layouts are allowed + // the check does not fail on array shapes without layouts. + static Status ValidateLayoutInShape(const Shape& shape, + bool allow_missing_layouts = false); // Validates that the provided layout satisfies invariants for the given // shape. diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index f25dae6ff411133c74502039f441060f1329ffd4..a50d53eaeb15daa9f7a98a816e180d3a55568bb8 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -352,5 +352,92 @@ TEST_F(LayoutUtilTest, StreamOut) { EXPECT_EQ(oss.str(), "{0,1,2}"); } +TEST_F(LayoutUtilTest, ValidateLayout_ValidArrayLayout) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1}); + auto status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); + EXPECT_TRUE(status.ok()); + status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); + EXPECT_TRUE(status.ok()); +} + +TEST_F(LayoutUtilTest, ValidateLayout_InvalidArrayLayout) { + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + *shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2}); + auto status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("layout minor_to_major field " + "contains 3 elements, but shape is rank 2")); + status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); + 
EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("layout minor_to_major field " + "contains 3 elements, but shape is rank 2")); +} + +TEST_F(LayoutUtilTest, ValidateLayout_MissingArrayLayout) { + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + LayoutUtil::ClearLayout(&shape); + auto status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("shape f32[2,3] does not have a layout")); + status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); + EXPECT_TRUE(status.ok()); +} + +TEST_F(LayoutUtilTest, ValidateLayout_TupleWithLayout) { + Shape shape = ShapeUtil::MakeTupleShape({}); + *shape.mutable_layout() = LayoutUtil::MakeLayout({0}); + auto status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("tuple should not have a layout field")); + status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("tuple should not have a layout field")); +} + +TEST_F(LayoutUtilTest, ValidateLayout_TupleSubshapesWithMissingLayouts) { + Shape sub_1_1_1 = ShapeUtil::MakeShape(F32, {1, 2}); + Shape sub_1_1 = ShapeUtil::MakeTupleShape({sub_1_1_1}); + Shape sub_1_2 = ShapeUtil::MakeShape(F32, {1, 2}); + LayoutUtil::ClearLayout(&sub_1_2); + Shape sub_1 = ShapeUtil::MakeTupleShape({sub_1_1, sub_1_2}); + Shape sub_2_1 = ShapeUtil::MakeShape(F32, {9}); + LayoutUtil::ClearLayout(&sub_2_1); + Shape sub_2 = ShapeUtil::MakeTupleShape({sub_2_1}); + Shape shape = ShapeUtil::MakeTupleShape({sub_1, sub_2}); + + auto status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + 
::testing::HasSubstr("shape f32[1,2] does not have a layout")); + status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); + EXPECT_TRUE(status.ok()); + + // Add invalid layout on one of sub-shapes. + *shape.mutable_tuple_shapes(1)->mutable_tuple_shapes(0)->mutable_layout() = + LayoutUtil::MakeLayout({0, 2, 3}); + + status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("layout minor_to_major field " + "contains 3 elements, but shape is rank 1")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 5035f4198890857fcafd0156d7eaeeb4bc164322..656ce720a13d5c9622e9dc05ae04ddcac8cbeee5 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -287,6 +287,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, return InvalidArgument("LiteralProto has no layout"); } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + Literal literal(proto.shape()); TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus( @@ -725,16 +727,34 @@ Literal LiteralBase::Slice(absl::Span start_indices, ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape())); switch (result_shape.element_type()) { - case F32: - return SliceInternal(result_shape, start_indices); + case PRED: + return SliceInternal(result_shape, start_indices); + case U8: + return SliceInternal(result_shape, start_indices); + case U16: + return SliceInternal(result_shape, start_indices); + case U32: + return SliceInternal(result_shape, start_indices); + case U64: + return SliceInternal(result_shape, start_indices); + case S8: + return SliceInternal(result_shape, start_indices); + case S16: + return SliceInternal(result_shape, start_indices); + case S32: + 
return SliceInternal(result_shape, start_indices); + case S64: + return SliceInternal(result_shape, start_indices); + case F16: + return SliceInternal(result_shape, start_indices); case BF16: return SliceInternal(result_shape, start_indices); + case F32: + return SliceInternal(result_shape, start_indices); + case F64: + return SliceInternal(result_shape, start_indices); case C64: return SliceInternal(result_shape, start_indices); - case S32: - return SliceInternal(result_shape, start_indices); - case U32: - return SliceInternal(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); @@ -1850,6 +1870,24 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); + if (LayoutUtil::IsSparseArray(subshape())) { + // Compute the number of elements (indices) in the sparse shape and reserve + // the necessary space in spare_indices. + TF_RET_CHECK(ShapeUtil::Rank(subshape()) != 0) + << "Scalar shapes cannot be sparse"; + TF_RET_CHECK(proto.sparse_indices_size() % ShapeUtil::Rank(subshape()) == 0) + << "Unexpected number of indices in proto (" + << proto.sparse_indices_size() << ") for shape of rank " + << ShapeUtil::Rank(subshape()); + const int64 index_count = + proto.sparse_indices_size() / ShapeUtil::Rank(subshape()); + sparse_indices()->Resize(index_count); + + // Copy the indices from the proto into the SparseIndexArray object. 
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(sparse_indices()->mutable_data(), + proto.sparse_indices())); + } + switch (subshape().element_type()) { case PRED: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); @@ -1907,11 +1945,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { } } break; case TUPLE: - LOG(FATAL) << "Should not be called on tuple shapes: " - << ShapeUtil::HumanString(subshape()); - break; + return InvalidArgument("Should not be called on tuple shapes: %s", + ShapeUtil::HumanString(subshape())); default: - LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); + return InvalidArgument("Is called on unsupported shape: %s", + ShapeUtil::HumanString(subshape())); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 7ad287c8973367fb04583e6911ff75e76bdf5f1e..dd5b54e4c99998f676419cf98a3da16593338829 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -224,6 +224,16 @@ TEST_F(LiteralUtilTest, CreateSparse) { absl::Span(expected_indices.data(), expected_indices.num_elements())); EXPECT_EQ(literal.data(), absl::Span(expected_values)); + + // Serialize then deserialize and verify the resulting literal. 
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal_from_proto, + Literal::CreateFromProto(literal.ToProto())); + + EXPECT_EQ(literal_from_proto.sparse_indices()->data(), + absl::Span(expected_indices.data(), + expected_indices.num_elements())); + EXPECT_EQ(literal_from_proto.data(), + absl::Span(expected_values)); } TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index cd5fd330298fb0ff158e232dac121f8ffb271218..92df404b8ec0aed4899906877a4dd41102bdf7a0 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -532,10 +532,13 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( const LocalComputation& local_computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding) { return xla::ReduceWindowWithGeneralPadding( operand.op(), init_value.op(), local_computation.computation(), - window_dimensions, window_strides, padding); + window_dimensions, window_strides, base_dilations, window_dilations, + padding); } LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, @@ -569,13 +572,13 @@ StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { } LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { - return xla::Sort(operand.op(), absl::nullopt, dimension); + return xla::Sort(operand.op(), {}, dimension); } LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, const LocalOp& values, int64 dimension) { - return xla::Sort(keys.op(), values.op(), dimension); + return xla::Sort(keys.op(), {values.op()}, dimension); } StatusOr LocalComputationBuilder::BuildConstantSubGraph( diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h 
b/tensorflow/compiler/xla/python/local_computation_builder.h index 2166bb6721ca380f3180a8802e4922f2e9e45945..43332e0abd410c08dc5a40f7de39dbc96d34a72c 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -278,6 +278,8 @@ class LocalComputationBuilder { const LocalComputation& local_computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span > padding); LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index bb303c5678a2cac9a9e78925e857ab25c0c6d9be..f8197488fb3bacb312cc7fbf149b773851992b8a 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -995,7 +995,30 @@ class ComputationBuilder(object): window_strides) return self._client.ReduceWindowWithGeneralPadding( operand, init_value, computation_to_apply.c_local_computation, - window_dimensions, window_strides, pads) + window_dimensions, window_strides, (), (), pads) + + def ReduceWindowWithGeneralPadding( + self, operand, init_value, computation_to_apply, window_dimensions, + window_strides, base_dilations, window_dilations, padding): + """Enqueues a windowed reduction operation onto the computation. + + Args: + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). + computation_to_apply: a binary reduction function (Computation). + window_dimensions: dimensions of window (sequence of integers). + window_strides: strides for window (sequence of integers). + base_dilations: dilations for the base (sequence of integers). + window_dilations: dilations for window (sequence of integers). + padding: length-N array-like of pairs of integers of (low, high) padding. + + Returns: + A LocalOp representing the added ReduceWindow op. 
+ """ + return self._client.ReduceWindowWithGeneralPadding( + operand, init_value, computation_to_apply.c_local_computation, + window_dimensions, window_strides, base_dilations, window_dilations, + padding) def RngNormal(self, mu, sigma, dims): """Enqueues an RngNormal operation onto the computation. diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index e800cf470cfd129f93c2a1be586e03bebcaec987..3a716c385b2bced7b36e65012d2ff6888525524a 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -146,6 +146,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -182,6 +184,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -251,6 +254,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -290,12 +294,14 @@ cc_library( srcs = [ "dfs_hlo_visitor.cc", "hlo_computation.cc", + "hlo_input_output_alias_config.cc", "hlo_instruction.cc", "hlo_instructions.cc", "hlo_module.cc", "hlo_opcode.cc", "hlo_schedule.cc", "hlo_sharding.cc", + "hlo_sharding_metadata.cc", ], hdrs = [ "dfs_hlo_visitor.h", @@ -303,12 +309,14 @@ cc_library( "hlo_clone_context.h", "hlo_computation.h", "hlo_domain_metadata.h", + "hlo_input_output_alias_config.h", "hlo_instruction.h", "hlo_instructions.h", "hlo_module.h", "hlo_opcode.h", "hlo_schedule.h", 
"hlo_sharding.h", + "hlo_sharding_metadata.h", ], deps = [ ":hlo_casting_utils", @@ -333,6 +341,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -395,6 +405,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", ], ) @@ -485,6 +496,8 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -776,6 +789,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -846,6 +860,7 @@ cc_library( ":executable", ":hlo", ":hlo_module_config", + ":hlo_module_group", ":logical_buffer", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -903,6 +918,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -952,6 +968,8 @@ cc_library( deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", 
"@com_google_absl//absl/strings", ], ) @@ -987,6 +1005,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], @@ -1034,6 +1054,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1087,6 +1109,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], @@ -1125,6 +1148,8 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -1146,6 +1171,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", ], ) @@ -1196,6 +1222,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], @@ -1216,6 +1243,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", 
"@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1242,6 +1271,25 @@ tf_cc_test( ], ) +tf_cc_test( + name = "hlo_input_output_alias_config_test", + srcs = ["hlo_input_output_alias_config_test.cc"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_memory_scheduler", + ":hlo_ordering", + ":hlo_parser", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + ], +) + cc_library( name = "hlo_memory_scheduler", srcs = ["hlo_memory_scheduler.cc"], @@ -1260,6 +1308,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -1280,6 +1330,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -1294,16 +1345,26 @@ cc_library( ], ) +cc_library( + name = "fusion_queue", + hdrs = ["fusion_queue.h"], + deps = [ + ":hlo", + ], +) + cc_library( name = "instruction_fusion", srcs = ["instruction_fusion.cc"], hdrs = ["instruction_fusion.h"], deps = [ + ":fusion_queue", ":hlo", ":hlo_pass", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", ], ) @@ -1330,6 +1391,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -1385,6 +1448,7 @@ cc_library( 
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], @@ -1640,6 +1704,8 @@ cc_library( ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -1671,6 +1737,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -1796,42 +1863,6 @@ tf_cc_test( ], ) -cc_library( - name = "inliner", - srcs = ["inliner.cc"], - hdrs = ["inliner.h"], - deps = [ - ":hlo", - ":hlo_pass", - ":hlo_query", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "@com_google_absl//absl/types:span", - ], -) - -tf_cc_test( - name = "inliner_test", - srcs = ["inliner_test.cc"], - deps = [ - ":cpu_plugin", - ":hlo", - ":hlo_matchers", - ":inliner", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "@com_google_absl//absl/memory", - ], -) - cc_library( name = "computation_placer", srcs = ["computation_placer.cc"], @@ -2043,6 +2074,7 @@ cc_library( ":logical_buffer", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -2078,6 +2110,7 @@ cc_library( "//tensorflow/compiler/xla:util", 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -2099,6 +2132,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -2182,6 +2216,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -2203,6 +2238,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -2263,6 +2300,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -2319,6 +2358,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -2345,6 +2386,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + 
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -2416,6 +2459,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -2428,6 +2472,7 @@ tf_cc_test( ":hlo", ":hlo_parser", ":hlo_verifier", + ":layout_assignment", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -2460,6 +2505,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -2588,6 +2635,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -2627,6 +2676,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", ], ) @@ -2701,26 +2751,12 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) -cc_library( - name = "hlo_sharding_metadata", - srcs = ["hlo_sharding_metadata.cc"], - hdrs = [ - "hlo_sharding_metadata.h", - ], - deps = [ - ":hlo", - "//tensorflow/compiler/xla:shape_tree", - 
"//tensorflow/compiler/xla:shape_util", - "//tensorflow/core:lib", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:span", - ], -) - cc_library( name = "hlo_domain_verifier", srcs = ["hlo_domain_verifier.cc"], @@ -2771,7 +2807,6 @@ tf_cc_test( ":hlo_domain_isolator", ":hlo_domain_remover", ":hlo_parser", - ":hlo_sharding_metadata", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -3114,6 +3149,7 @@ cc_library( ":buffer_assignment", ":hlo", ":hlo_proto", + ":hlo_verifier", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:util", ], @@ -3147,6 +3183,7 @@ cc_library( ":hlo_pass_pipeline", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -3175,6 +3212,7 @@ cc_library( ":computation_placer", ":executable", ":hlo", + ":hlo_module_group", ":transfer_manager", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -3269,6 +3307,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", ], ) @@ -3298,6 +3338,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", ], ) @@ -3354,6 +3395,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -3381,7 +3424,6 @@ cc_library( 
deps = [ ":hlo", ":hlo_lexer", - ":hlo_sharding_metadata", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -3439,6 +3481,39 @@ cc_library( deps = ["//tensorflow/core:lib"], ) +cc_library( + name = "map_inliner", + srcs = ["map_inliner.cc"], + hdrs = ["map_inliner.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":hlo_query", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "map_inliner_test", + srcs = ["map_inliner_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":map_inliner", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/memory", + ], +) + tf_cc_test( name = "hlo_casting_utils_test", srcs = ["hlo_casting_utils_test.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 75dae7a7141647d7b7b60b0e07e11c143621ea63..72ed5ca48217298cab6dc63b1f2dd30a0730817d 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -157,6 +157,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) override; + Status HandleSelect(HloInstruction* select) override; + Status HandleSort(HloInstruction* sort) override; Status HandleTranspose(HloInstruction* transpose) override; @@ -2057,6 +2059,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return Status::OK(); } + // Bail on dilation. 
+ if (window_util::HasDilation(window)) { + VLOG(10) << "Not folding pad into reduce-window as there is dilation."; + return Status::OK(); + } + VLOG(10) << "Considering folding Pad: " << pad->ToString() << "\ninto reduce-window: " << reduce_window->ToString() << (convert != nullptr @@ -2193,6 +2201,22 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( /*reduce_computation=*/function)); } +Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) { + // select(x, y, y) -> y. + if (select->operand(1) == select->operand(2)) { + return ReplaceInstruction(select, select->mutable_operand(1)); + } + // select(true, x, y) -> x. + if (IsAll(select->operand(0), true)) { + return ReplaceInstruction(select, select->mutable_operand(1)); + } + // select(false, x, y) -> y. + if (IsAll(select->operand(0), false)) { + return ReplaceInstruction(select, select->mutable_operand(2)); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { auto operand = sort->mutable_operand(0); int64 dimension_to_sort = sort->dimensions(0); @@ -2203,7 +2227,7 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { } // If it is key/value sort, the output of sort is a tuple. 
return ReplaceWithNewInstruction( - sort, HloInstruction::CreateTuple({operand, sort->mutable_operand(1)})); + sort, HloInstruction::CreateTuple(sort->operands())); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 2047f894b465816eb97ba205e79033bd52bf7a0c..c79c518700b63be2fda8a415b38fe246689ab7c6 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -97,6 +97,73 @@ TEST_F(AlgebraicSimplifierTest, MulZero) { EXPECT_EQ(computation->root_instruction(), zero); } +// Test that select(true, a, b) is simplified to a +TEST_F(AlgebraicSimplifierTest, SelectTrue) { + Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0s32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0s32, "param1")); + HloInstruction* one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction(HloInstruction::CreateTernary( + r0s32, HloOpcode::kSelect, one, param0, param1)); + + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kSelect); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_EQ(computation->root_instruction(), param0); +} + +// Test that select(false, a, b) is simplified to b +TEST_F(AlgebraicSimplifierTest, SelectFalse) { + Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + 
HloInstruction::CreateParameter(0, r0s32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0s32, "param1")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateTernary( + r0s32, HloOpcode::kSelect, zero, param0, param1)); + + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kSelect); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_EQ(computation->root_instruction(), param1); +} + +// Test that select(a, b, b) is simplified to b +TEST_F(AlgebraicSimplifierTest, SelectIdentical) { + Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0s32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0s32, "param1")); + builder.AddInstruction(HloInstruction::CreateTernary( + r0s32, HloOpcode::kSelect, param0, param1, param1)); + + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kSelect); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_EQ(computation->root_instruction(), param1); +} + // Test that Reduce(Reduce(A)) -> Reduce(A) TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { HloComputation::Builder builder(TestName()); @@ -2133,16 +2200,20 @@ TEST_F(AlgebraicSimplifierTest, 
ReplaceEffectiveScalarKeyValueSortWithTuple) { Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - auto values = builder.AddInstruction( - HloInstruction::CreateParameter(1, values_shape, "values")); + auto values0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values0")); + auto values1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, values_shape, "values1")); builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}), 0, + keys, {values0, values1})); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values)); + EXPECT_THAT(computation->root_instruction(), + op::Tuple(keys, values0, values1)); } // Used for TEST_Ps that test merging (or not) of a kPad instruction into a diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 1ed6142dcecdc830cb7b8386e0cc20a2ea54aa7f..ef5e211646e7b0b66b8e6c09948be58063422943 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -176,13 +176,13 @@ StatusOr> AllocationTracker::DeconstructTuple( } StatusOr> AllocationTracker::Resolve( - const GlobalDataHandle& data) { + const GlobalDataHandle& data) const { tensorflow::mutex_lock lock(mutex_); return AllocationTracker::ResolveInternal(data); } StatusOr AllocationTracker::ResolveForReplica( - const GlobalDataHandle& data, int replica_id) { + const GlobalDataHandle& data, int replica_id) const { 
tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, ResolveInternal(data)); @@ -196,7 +196,7 @@ StatusOr AllocationTracker::ResolveForReplica( } StatusOr> AllocationTracker::ResolveInternal( - const GlobalDataHandle& data) { + const GlobalDataHandle& data) const { VLOG(2) << "resolve:" << data.handle(); auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index a7d8927cf7e90d764ff8046df16c71922b11478e..98d1a302a9f66f4a00e05d62837a79133e222687 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -64,13 +65,13 @@ class AllocationTracker { // replica, or provide an error status to say whether any of those buffers // were not found (or found, but found deallocated). StatusOr> Resolve( - const GlobalDataHandle& data); + const GlobalDataHandle& data) const; // Resolves a handle from an XLA client and replica id to a shaped buffer, or // provide an error status to say whether it was not found (or found, but // found deallocated). StatusOr ResolveForReplica(const GlobalDataHandle& data, - int replica_id); + int replica_id) const; private: // Data structure encapsulating single memory allocation on the device. @@ -86,7 +87,7 @@ class AllocationTracker { // Internal helper which resolves the given GlobalDataHandle to a // list of ScopedShapedBuffers. 
StatusOr> ResolveInternal( - const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + const GlobalDataHandle& data) const EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Internal helper which registers a vector of shaped buffers, one per // replica. ShapedBufferTy is either ScopedShapedBuffer or ShapedBuffer. If @@ -110,9 +111,9 @@ class AllocationTracker { // A map from device memory opaque value to allocation. One such map is // maintained per device ordinal. - using AllocationMap = tensorflow::gtl::FlatMap; + using AllocationMap = absl::flat_hash_map; - tensorflow::mutex mutex_; + mutable tensorflow::mutex mutex_; // Backend to use with this tracker. The backend supplies the memory allocator // to use when deallocating memory. @@ -123,10 +124,7 @@ class AllocationTracker { int64 next_handle_ GUARDED_BY(mutex_); // A map from device ordinal to AllocationMap. - // - // This is not a TF FlatMap because (currently) FlatMap (and therefore - // AllocationMap) is not movable. - std::unordered_map opaque_to_allocation_map_ + absl::flat_hash_map opaque_to_allocation_map_ GUARDED_BY(mutex_); // A map from data handle to a vector of shaped buffers that represent the @@ -146,7 +144,7 @@ class AllocationTracker { // non-owning "view" into a tuple's sub-buffers. The sub-buffers are then // free'd when both the view *and* the original tuple are Unregistered. This // refcounting is managed in opaque_to_allocation_map_. - tensorflow::gtl::FlatMap>> + absl::flat_hash_map>> handle_to_shaped_buffers_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 30d33e0d3531bb5e931ebfa0b60c91523dd0cb44..f70f6ddfec69c0113a1afe2073a2392098f49456 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -35,7 +35,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index d5b1148058898596bfdb837826a590bbc74e202a..1251f0258f5d43a490ad654f519fee9076590453 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -231,6 +231,10 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs( for (auto* user : materialized_users) { TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, tuple)); } + bool is_root = computation_->root_instruction() == hlo; + if (is_root) { + computation_->set_root_instruction(tuple); + } *tuple->mutable_shape() = original_shape; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index cef0eba14e9dd463d6c32b047211bf25a84478f6..cb075a5e38a5ea9db2ceb432b2b59f8db5e2e640 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -284,7 +284,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { HloInstruction::CreateParameter(1, s32_shape, "value")); HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, value)); + ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, {value})); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); @@ -298,6 +298,30 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), 
F32); } +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {1024}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024}); + + HloInstruction* key = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "key")); + HloInstruction* value = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "value")); + + HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({bf16_shape, bf16_shape}), 0, key, {value})); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module)); + + EXPECT_EQ(sort->operand(0)->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32); + EXPECT_NE(computation->root_instruction(), sort); + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kTuple); +} + // Tests that the normalization should not cause unsupported mixed precision due // to resolving unsupported BF16 operand. TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 58f78f8e24d0bc00a63e3583828cf8e01ae4531a..002be9c97098ef1f73446c458dae24bbc826a626 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/bfloat16_propagation.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -81,7 +82,7 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes( }; auto root = fusion->fused_instructions_computation()->root_instruction(); - tensorflow::gtl::FlatSet changed_root_buffers; + absl::flat_hash_set changed_root_buffers; auto root_changes_it = changes_to_bf16_.find(root); if (root_changes_it != changes_to_bf16_.end()) { @@ -500,7 +501,7 @@ void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) { bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( HloComputation* computation, - tensorflow::gtl::FlatSet* visited_computations) { + absl::flat_hash_set* visited_computations) { bool parameter_changed = false; auto insts = computation->MakeInstructionPostOrder(); // Do the adjustment on each instruction in the computation in reverse @@ -560,7 +561,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( // another input parameter. A fixed point will be reached because the // parameters can only be changed from BF16 to F32, not the other way // around. 
- tensorflow::gtl::FlatSet visited_in_while; + absl::flat_hash_set visited_in_while; while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(), &visited_in_while) || ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(), @@ -587,7 +588,7 @@ void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( HloModule* module) { const auto& computations_topological_order = module->MakeComputationPostOrder(); - tensorflow::gtl::FlatSet resolved; + absl::flat_hash_set resolved; for (auto comp_it = computations_topological_order.rbegin(); comp_it != computations_topological_order.rend(); ++comp_it) { if (ContainsKey(resolved, *comp_it)) { diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 6a62439f8877634a065979d1e2fcda262ca83dc1..5fcaa15c8356107af02e9099874a293d8350c51a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/bfloat16_support.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -81,7 +83,7 @@ class BFloat16Propagation : public HloModulePass { // The set of instructions to consider using bfloat16, computed in the forward // pass. - tensorflow::gtl::FlatSet consider_using_bfloat16_; + absl::flat_hash_set consider_using_bfloat16_; // *************************** // Functions called and state produced by the backward pass (from root to @@ -110,12 +112,12 @@ class BFloat16Propagation : public HloModulePass { // The set of HloInstructions that have been visited in the // opportunity-finding pass. 
- tensorflow::gtl::FlatSet + absl::flat_hash_set instructions_visited_in_backward_pass_; // The set of HloComputations that have been visited in the // opportunity-finding pass. - tensorflow::gtl::FlatSet + absl::flat_hash_set computations_visited_in_backward_pass_; // *************************** @@ -131,7 +133,7 @@ class BFloat16Propagation : public HloModulePass { // point is reached. bool ResolveInconsistencyOfAliasingBuffersHelper( HloComputation* computation, - tensorflow::gtl::FlatSet* visited_computations); + absl::flat_hash_set* visited_computations); // Makes the parameters of called computations match how they are called by // the given HLO. @@ -182,11 +184,11 @@ class BFloat16Propagation : public HloModulePass { PrimitiveType target_type); // The set of F32 HLO values that must be kept in F32. - tensorflow::gtl::FlatSet values_that_must_be_kept_as_f32_; + absl::flat_hash_set values_that_must_be_kept_as_f32_; // Mapping from each HloComputation to the number of callers to it in the // module. Populated at the beginning of this pass. - tensorflow::gtl::FlatMap caller_counts_; + absl::flat_hash_map caller_counts_; // We first store the potential F32-to-BF16 changes to changes_to_bf16_, which // are subject to further adjustment, then finally applied to the HLOs. This @@ -195,8 +197,7 @@ class BFloat16Propagation : public HloModulePass { // // For each HloInstruction, changes_to_bf16_ stores the affected buffers in // the output as a map from in-place pointers to subshapes to shape indices. - tensorflow::gtl::FlatMap> + absl::flat_hash_map> changes_to_bf16_; // Whether the last processed HLO module has been changed by this pass. 
diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 23645346e6f491beb5171cc839c013ce5f83d789..5b48f10505e78c035608d4c575501e4623218987 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -78,8 +78,10 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( const HloInstruction& hlo, int64 operand_index) { switch (hlo.opcode()) { case HloOpcode::kAbs: + case HloOpcode::kAllToAll: case HloOpcode::kBroadcast: case HloOpcode::kClamp: + case HloOpcode::kCollectivePermute: case HloOpcode::kConcatenate: case HloOpcode::kConvert: case HloOpcode::kCopy: diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 34a7be0e9c079e9e42c28eef10af4079e99853b6..d5d6a044a81303425495202d8a98c6735b0b8b89 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -41,10 +43,10 @@ limitations under the License. namespace xla { namespace { +using absl::flat_hash_map; +using absl::flat_hash_set; using absl::StrAppend; using absl::StrAppendFormat; -using ::tensorflow::gtl::FlatMap; -using ::tensorflow::gtl::FlatSet; using ::tensorflow::strings::HumanReadableNumBytes; template @@ -128,8 +130,8 @@ Status GatherComputationsByAllocationType( // Sets for quickly checking membership. Computations are returned in vectors // for stable iteration. 
- FlatSet thread_local_set; - FlatSet global_set; + flat_hash_set thread_local_set; + flat_hash_set global_set; while (!worklist.empty()) { auto worklist_front = worklist.front(); @@ -237,7 +239,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice( void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size) { - VLOG(4) << "Trying to add " << buffer << " to " << this; + VLOG(4) << "Trying to add " << buffer << " to allocation #" << index(); CHECK(assigned_buffers_.count(&buffer) == 0) << "LogicalBuffer " << buffer << " already assigned to allocation " << index_; @@ -444,7 +446,7 @@ bool BufferAssignment::SharesSliceAtIndex( bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, const HloInstruction* hlo_b) const { using SliceSet = - FlatSet; + flat_hash_set; // Gets the slices all of instr's subshapes. If any subshape doesn't have an // assigned slice, returns the empty set. auto collect_slices = [&](const HloInstruction* instr) -> SliceSet { @@ -519,7 +521,8 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation, // BufferAllocation. void BufferAssignment::CombineTempAllocations() { VLOG(1) << "CombineTempAllocations()"; - FlatMap + flat_hash_map combined_allocation_map; // Move all temp allocations into a single run at the end of the allocations @@ -582,7 +585,8 @@ void BufferAssignment::CombineTempAllocations() { } // Update allocation indices to their new positions. 
- allocation_index_for_buffer_.clear_no_resize(); + allocation_index_for_buffer_.erase(allocation_index_for_buffer_.begin(), + allocation_index_for_buffer_.end()); for (size_t index = 0; index < allocations_.size(); ++index) { BufferAllocation* allocation = &allocations_[index]; allocation->set_index(index); @@ -780,21 +784,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } } - if (allow_input_output_aliasing_ && allocation->maybe_live_out()) { - const HloComputation* entry_computation = - assignment->module_->entry_computation(); - for (auto param : entry_computation->parameter_instructions()) { - for (auto& param_buffer : - assignment->points_to_analysis().GetBuffersDefinedByInstruction( - param)) { - if (assignment->liveness().MayInterfere(*param_buffer, buffer)) { - VLOG(4) << "Can't assign: Parameter interference with result"; - return false; - } - } - } - } - // If the buffer is live out of the computation then it should only be // assigned a buffer which exactly fits the result to avoid wasting memory // (result buffers can have arbitrary lifetimes). @@ -812,9 +801,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, Status BufferAssigner::AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const FlatSet& colocated_buffers, - const FlatSet& colocated_allocations, - FlatMap>* + const flat_hash_set& colocated_buffers, + const flat_hash_set& colocated_allocations, + flat_hash_map>* buffers_to_assign_sequentially, BufferAssignment* assignment) { // Buffers are sorted and assigned to BufferAllocations in decreasing order of @@ -833,7 +822,7 @@ Status BufferAssigner::AssignBuffersForComputation( // Generate a post order sort of instructions for sorting of the // LogicalBuffers. 
- FlatMap post_order_position; + flat_hash_map post_order_position; int position = 0; for (auto* instruction : computation->MakeInstructionPostOrder()) { post_order_position.emplace(instruction, position); @@ -850,8 +839,8 @@ Status BufferAssigner::AssignBuffersForComputation( // buffers_to_assign_sequentially map, even if we end up with an empty set // of buffers. This ensures we can correctly determine whether to run // whole-module heap simulation. - buffers_to_assign_sequentially->emplace(computation, - FlatSet()); + buffers_to_assign_sequentially->emplace( + computation, flat_hash_set()); } // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers @@ -1043,12 +1032,12 @@ Status BufferAssigner::AssignBuffersForComputation( return Status::OK(); } -FlatMap, - LogicalBuffer::Color::Hasher> +flat_hash_map, + LogicalBuffer::Color::Hasher> BufferAssigner::SplitBuffersByColor( - const FlatSet& buffers) { - FlatMap, - LogicalBuffer::Color::Hasher> + const flat_hash_set& buffers) { + flat_hash_map, + LogicalBuffer::Color::Hasher> color_map; for (auto buffer : buffers) { color_map[buffer->color()].insert(buffer); @@ -1057,7 +1046,8 @@ BufferAssigner::SplitBuffersByColor( } Status BufferAssigner::AssignBuffersWithSequentialOrdering( - const FlatMap>& + const flat_hash_map>& buffers_to_assign_sequentially, bool run_whole_module_heap_simulation, BufferAssignment* assignment) { // Run the sequence of instructions through the heap simulator. The heuristic @@ -1083,10 +1073,11 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( // only live for the duration of their calling instructions. 
VLOG(1) << "Running whole-module heap simulation"; HloSchedule schedule(&assignment->module()); - FlatSet all_buffers_to_assign; + flat_hash_set all_buffers_to_assign; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; - const FlatSet& buffers_to_assign = pair.second; + const flat_hash_set& buffers_to_assign = + pair.second; const std::vector* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); @@ -1120,7 +1111,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(1) << "Running per-computation heap simulation"; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; - const FlatSet& buffers_to_assign = pair.second; + const flat_hash_set& buffers_to_assign = + pair.second; const std::vector* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); @@ -1155,9 +1147,8 @@ std::vector ComputePeakMemoryLogicalBuffers( const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) { // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical // buffers in this allocation. - tensorflow::gtl::FlatMap - id_to_buffer; - tensorflow::gtl::FlatMap buffer_sizes; + absl::flat_hash_map id_to_buffer; + absl::flat_hash_map buffer_sizes; for (const auto& pair : allocation.assigned_buffers()) { const LogicalBuffer* buffer = pair.first; const BufferAllocation::OffsetSize& offset_size = pair.second; @@ -1196,7 +1187,7 @@ std::vector ComputePeakMemoryLogicalBuffers( // Next gather the set of logical buffers live at the earliest point of // maximal live set size. 
- tensorflow::gtl::FlatSet live_buffers; + absl::flat_hash_set live_buffers; live_size = 0; for (const auto& event : heap_trace.events()) { const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); @@ -1428,13 +1419,28 @@ BufferAssigner::MergeColocatedBufferSets( // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated // in the same allocation (currently just supports kWhile, kCall, and -// kConditional). +// kConditional, plus input/output aliasing). void BufferAssigner::BuildColocatedBufferSets( const HloModule* module, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, std::vector* colocated_buffer_sets) { const TuplePointsToAnalysis& points_to_analysis = buffer_liveness.points_to_analysis(); + + // Set up colocated buffer set for input and output. + module->input_output_alias_config().ForEachAlias( + [&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + std::vector colocated_set; + AddBufferToColocatedSet(module->entry_computation()->root_instruction(), + output_index, points_to_analysis, + &colocated_set); + AddBufferToColocatedSet( + module->entry_computation()->parameter_instruction(param_number), + param_index, points_to_analysis, &colocated_set); + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + }); + for (const HloComputation* computation : module->MakeComputationPostOrder()) { if (computation->IsFusionComputation()) { continue; @@ -1586,8 +1592,8 @@ void BufferAssigner::BuildColocatedBufferSets( void BufferAssigner::AssignColocatedBufferSets( const std::vector& colocated_buffer_sets, BufferAssignment* assignment, - FlatSet* colocated_buffers, - FlatSet* colocated_allocations) { + flat_hash_set* colocated_buffers, + flat_hash_set* colocated_allocations) { for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) { BufferAllocation* allocation = nullptr; // Set 'entry_parameter_number' and
'entry_parameter_shape_idx' if entry @@ -1660,8 +1666,8 @@ StatusOr> BufferAssigner::CreateAssignment( // Once b/32491382 enables module-level liveness analysis, we may be able // to assign colocated buffers (or at least reuse their allocation for // buffers outside of the set) in AssignBuffersForComputation. - FlatSet colocated_buffers; - FlatSet colocated_allocations; + flat_hash_set colocated_buffers; + flat_hash_set colocated_allocations; std::vector colocated_buffer_sets; BuildColocatedBufferSets(module, assignment->liveness(), assignment->buffer_size_, &colocated_buffer_sets); @@ -1679,7 +1685,7 @@ StatusOr> BufferAssigner::CreateAssignment( // First assign buffers for global computatations. Temporary buffers for // sequential computations are collected in 'buffers_to_assign_sequentially'. - FlatMap> + flat_hash_map> buffers_to_assign_sequentially; for (auto* computation : global_computations) { TF_RETURN_IF_ERROR(AssignBuffersForComputation( diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 24ba7c16f548c10f58f41d2b88488939ca2d8e4d..899cd36e1f98c9e7b8ba7e42c06ced5c3e8afcc8 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" @@ -33,8 +35,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -148,7 +148,7 @@ class BufferAllocation { // Access to the logical buffers assigned to this allocation, and their // associated logical offsets and sizes. - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& assigned_buffers() const { return assigned_buffers_; } @@ -323,7 +323,7 @@ class BufferAllocation { // Mapping from the set of buffers assigned to this allocation to their // logical offsets and sizes. - tensorflow::gtl::FlatMap assigned_buffers_; + absl::flat_hash_map assigned_buffers_; int64 fragmentation_bytes_ = 0; std::vector heap_traces_; @@ -500,7 +500,7 @@ class BufferAssignment { int64 temp_allocation_total_size_ = 0; // Maps Buffers to the index of the BufferAllocation which holds the buffer. - tensorflow::gtl::FlatMap + absl::flat_hash_map allocation_index_for_buffer_; const HloModule* module_; @@ -554,11 +554,10 @@ class BufferAssigner { // true. Status AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const tensorflow::gtl::FlatSet& colocated_buffers, - const tensorflow::gtl::FlatSet& - colocated_allocations, - tensorflow::gtl::FlatMap>* + const absl::flat_hash_set& colocated_buffers, + const absl::flat_hash_set& colocated_allocations, + absl::flat_hash_map>* buffers_to_assign_sequentially, BufferAssignment* assignment); @@ -568,9 +567,8 @@ class BufferAssigner { // 'run_whole_module_heap_simulation' is true, the heap simulation will be run // assuming all global computations are sequentially ordered. 
Status AssignBuffersWithSequentialOrdering( - const tensorflow::gtl::FlatMap< - const HloComputation*, - tensorflow::gtl::FlatSet>& + const absl::flat_hash_map>& buffers_to_assign_sequentially, bool run_whole_module_heap_simulation, BufferAssignment* assignment); @@ -590,7 +588,7 @@ class BufferAssigner { // alias. Explicitly handling these colocated buffers is necessary because // points-to analysis is computation level scope and does not recognize // aliasing across computations (b/32491382). - using ColocatedBufferSet = tensorflow::gtl::FlatSet; + using ColocatedBufferSet = absl::flat_hash_set; // Returns a vector of ColocatedBufferSet objects, where each // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module' @@ -605,8 +603,8 @@ class BufferAssigner { void AssignColocatedBufferSets( const std::vector& colocated_buffer_sets, BufferAssignment* assignment, - tensorflow::gtl::FlatSet* colocated_buffers, - tensorflow::gtl::FlatSet* colocated_allocations); + absl::flat_hash_set* colocated_buffers, + absl::flat_hash_set* colocated_allocations); // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining // the invariant that all sets in 'colocated_buffer_sets' are disjoint. @@ -624,11 +622,10 @@ class BufferAssigner { // Split a set of buffers into several sets, each of which contains buffers // colored with the same color. - tensorflow::gtl::FlatMap, - LogicalBuffer::Color::Hasher> - SplitBuffersByColor( - const tensorflow::gtl::FlatSet& buffers); + absl::flat_hash_map, + LogicalBuffer::Color::Hasher> + SplitBuffersByColor(const absl::flat_hash_set& buffers); // If true, buffer assignments assumes that input parameter buffers and output // buffers can be shared if their sizes match. 
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h index cdd3cf4032ef6916086e1c2d148b575192503000..f939a426ead7c34092fc5234ef779ee857347a26 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.h +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -27,8 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -102,7 +101,7 @@ class BufferLiveness { // Set of LogicalBuffers which are aliased in the output of other // instructions. For example, a LogicalBuffer which is inserted into a tuple // is considered to be aliased and will be in this set. - tensorflow::gtl::FlatSet aliased_buffers_; + absl::flat_hash_set aliased_buffers_; // LogicalBuffers that may be live out of the entry computation. PointsToSet::BufferSet maybe_live_out_buffers_; diff --git a/tensorflow/compiler/xla/service/buffer_value.h b/tensorflow/compiler/xla/service/buffer_value.h index 69b36463560a1fad4f62687e9014fb3fbe5bbd13..11d8abc5badf7b1a05239ed74a05be0c899e37a1 100644 --- a/tensorflow/compiler/xla/service/buffer_value.h +++ b/tensorflow/compiler/xla/service/buffer_value.h @@ -141,6 +141,9 @@ class BufferValue { // operator< is required for std::set. 
bool operator<(const BufferValue& other) const { return id_ < other.id_; } + bool operator==(const BufferValue& other) const { return id_ == other.id_; } + bool operator!=(const BufferValue& other) const { return id_ != other.id_; } + virtual string ToString() const = 0; // TODO(lauj) rename LogicalBufferProto to BufferValueProto. diff --git a/tensorflow/compiler/xla/service/buffer_value_containers.h b/tensorflow/compiler/xla/service/buffer_value_containers.h index 305914fca828f110bf54239bddb1590172562b16..cc46af5eeec623e19637cd6245915b3a3124a2cd 100644 --- a/tensorflow/compiler/xla/service/buffer_value_containers.h +++ b/tensorflow/compiler/xla/service/buffer_value_containers.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/core/lib/gtl/compactptrset.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -38,7 +38,7 @@ BufferValueCompactPointerSet ToBufferValueCompactPointerSet( return output; } -using BufferValueFlatSet = tensorflow::gtl::FlatSet; +using BufferValueFlatSet = absl::flat_hash_set; template BufferValueFlatSet ToBufferValueFlatSet( const LogicalBufferContainerT& logical_buffer_container) { diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 23b2a327096dfdb3c756a4acc5476ec01dcac1b3..bdd5069632e84fe6c67ca129f726432479ac1b35 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -138,7 +139,7 @@ CallGraphNode& CallGraph::GetNode(const HloComputation* computation) { bool CallGraph::DominatesHelper( const HloComputation* a, const HloComputation* b, - tensorflow::gtl::FlatSet* visited) const { + absl::flat_hash_set* visited) const { if (a == b || ContainsKey(*visited, b)) { // The call graph is guaranteed to be acyclic so any previously visited node // we encounter was already determined to be dominated. @@ -163,7 +164,7 @@ bool CallGraph::DominatesHelper( bool CallGraph::Dominates(const HloComputation* a, const HloComputation* b) const { - tensorflow::gtl::FlatSet visited; + absl::flat_hash_set visited; return DominatesHelper(a, b, &visited); } @@ -277,7 +278,7 @@ std::unique_ptr CallGraph::Build(const HloModule* module) { Status CallGraph::VisitNodesInternal( const VisitorFunction& visitor_func, const CallGraphNode& node, - tensorflow::gtl::FlatSet* visited) const { + absl::flat_hash_set* visited) const { auto pair = visited->insert(&node); if (!pair.second) { // Node was not inserted. Node has already been visited. @@ -294,7 +295,7 @@ Status CallGraph::VisitNodesInternal( Status CallGraph::VisitNodes(const VisitorFunction& visitor_func, bool visit_unreachable_nodes) const { - tensorflow::gtl::FlatSet visited; + absl::flat_hash_set visited; if (visit_unreachable_nodes) { // Traverse from all roots in the call graph. for (const CallGraphNode& node : nodes()) { diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 3af2ab5edfd9faf4ac5193df4b823c21b55b2f7f..cb56f4789d06ac33acdaadc8b619b9e37f683d58 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -20,11 +20,11 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -145,19 +145,19 @@ class CallGraphNode { // The computations called by this computation. The vector is used for a // stable ordering and the set enables fast membership testing. std::vector callees_; - tensorflow::gtl::FlatSet callee_set_; + absl::flat_hash_set callee_set_; // The computations which call this computation. The vector is used for a // stable ordering and the set enables fast membership testing. std::vector callers_; - tensorflow::gtl::FlatSet caller_set_; + absl::flat_hash_set caller_set_; // The call sites in this computation std::vector callsites_; // The map from instruction to index in callsites_ for looking up the callsite // (if any) associated with a particular instruction in this computation. - tensorflow::gtl::FlatMap callsite_instructions_; + absl::flat_hash_map callsite_instructions_; // The call sites in other computations which call this computation. std::vector caller_callsites_; @@ -250,14 +250,14 @@ class CallGraph { // 'visited'. Status VisitNodesInternal( const VisitorFunction& visitor_func, const CallGraphNode& node, - tensorflow::gtl::FlatSet* visited) const; + absl::flat_hash_set* visited) const; // Recursive helper for computing whether 'a' dominates 'b' in the call // graph. 'b_ancestor' is the currently visited node (which starts at 'b'), // and 'visited' is the set of computations which have been visited. bool DominatesHelper( const HloComputation* a, const HloComputation* b, - tensorflow::gtl::FlatSet* visited) const; + absl::flat_hash_set* visited) const; // The HLO module represented by this call graph. 
const HloModule* module_ = nullptr; @@ -267,7 +267,7 @@ class CallGraph { // Map from HLO computation to the index of the corresponding call graph node // in nodes_. - tensorflow::gtl::FlatMap node_indices_; + absl::flat_hash_map node_indices_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 96bd2616f5607de888a096f8392ceb68490276e3..6d67f970020d278cc7bf61b56350200d3e5cb926 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -67,7 +67,7 @@ CompileOnlyService::CompileAheadOfTime( std::unique_ptr* metadata) { std::vector> hlo_modules; for (const AotXlaComputationInstance& instance : computations) { - TF_RET_CHECK(instance.computation.has_program_shape()); + TF_RET_CHECK(instance.computation.has_host_program_shape()); const DebugOptions& debug_options = options.debug_options(); @@ -86,9 +86,11 @@ CompileOnlyService::CompileAheadOfTime( Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot)); } - const auto& program_shape = instance.computation.program_shape(); + const auto& program_shape = instance.computation.host_program_shape(); ExecutionOptions execution_options; *execution_options.mutable_debug_options() = debug_options; + *execution_options.mutable_shape_with_output_layout() = + *instance.result_layout; TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(program_shape, instance.argument_layouts, @@ -101,8 +103,10 @@ CompileOnlyService::CompileAheadOfTime( hlo_modules.push_back(std::move(hlo_module)); } - return compiler_->CompileAheadOfTime(std::move(hlo_modules), options, - metadata); + return compiler_->CompileAheadOfTime( + absl::make_unique(hlo_modules[0]->name(), + absl::MakeSpan(hlo_modules)), + options, metadata); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc 
index 687ecafe0c308ecc22857fae650c6998677f605d..80c630c6201503d88a690f04a88f6fca6f3a438a 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -45,7 +45,7 @@ Compiler::ComputeDefaultBackendConfig(const HloInstruction& hlo, // Define a default version where metadata is not used. StatusOr>> Compiler::CompileAheadOfTime( - std::vector> modules, + std::unique_ptr module_group, const AotCompilationOptions& options, std::unique_ptr* metadata) { if (metadata != nullptr) { @@ -53,7 +53,7 @@ Compiler::CompileAheadOfTime( "Populating AotCompilationMetadata is not implemented on this " "compiler."); } - return CompileAheadOfTime(std::move(modules), options); + return CompileAheadOfTime(std::move(module_group), options); } /* static */ std::map* diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 1fdda31c34a17a16f75e1efada542c2c2ea15038..9ab179303b3e792c1f94c08626d7bc1afd2099f8 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -135,6 +136,12 @@ class Compiler { std::unique_ptr module, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) = 0; + // Optimizes a HLO module group, a set of module which runs concurrently on + // multiple devices potentially communicating data between the modules. 
+ virtual Status RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, se::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator) = 0; + // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. No HLO passes are // applied to module. Generally a module should be passed through RunHloPasses @@ -145,12 +152,18 @@ class Compiler { // (not just type of device) indicated by the executor. // // device_allocator is optional; see RunHloPasses. - // - // Use the overload below to compile computations that run in parallel. virtual StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) = 0; + // Compiles a set of HLO modules that can run in parallel, potentially + // communicating data between the modules. + virtual StatusOr>> + RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) = 0; + // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules, and returns a corresponding // sequence of executable objects. @@ -160,7 +173,7 @@ class Compiler { // TODO(b/68666782): Remove this method after adding support for multiple // modules to RunHloPasses and RunBackends. virtual StatusOr>> Compile( - std::vector> modules, + std::unique_ptr module_group, std::vector> stream_exec, DeviceMemoryAllocator* device_allocator) = 0; @@ -184,16 +197,16 @@ class Compiler { ComputeDefaultBackendConfig(const HloInstruction& hlo, se::StreamExecutor* executor) const; - // Compiles the HLO module for ahead-of-time execution. This is intended for - // use in static compilation. + // Compiles the HLO module group for ahead-of-time execution. This is + // intended for use in static compilation. 
virtual StatusOr>> - CompileAheadOfTime(std::vector> modules, + CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options) = 0; // Similar to CompileAheadOfTime above but AotCompilationMetadata // has an argument that can be populated during compilation. virtual StatusOr>> - CompileAheadOfTime(std::vector> modules, + CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options, std::unique_ptr* metadata); diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index af8f7f1027a40703137d6880a9865449c560a47b..efc893818d03a20d6bd65b7dc1da72ea5da5ceb0 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -56,4 +56,14 @@ string ComputationLayout::ToString() const { result_layout_.ToString()); } +ProgramShape ComputationLayout::ComputeProgramShape() const { + ProgramShape program_shape; + for (int64 i = 0; i < parameter_layouts_.size(); ++i) { + *program_shape.add_parameters() = parameter_layouts_[i].shape(); + *program_shape.add_parameter_names() = absl::StrCat("p", i); + } + *program_shape.mutable_result() = result_layout_.shape(); + return program_shape; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h index 6975f387b4864bf28ea0ad23d7d4602b5b346e08..a2fb656677f354fbf85ff613d826cd6be86ba3bf 100644 --- a/tensorflow/compiler/xla/service/computation_layout.h +++ b/tensorflow/compiler/xla/service/computation_layout.h @@ -83,6 +83,10 @@ class ComputationLayout { // Returns a string representation of this object. string ToString() const; + // Create a ProgramShape proto based on the parameter and result shapes held + // within this object. 
+ ProgramShape ComputeProgramShape() const; + private: std::vector parameter_layouts_; ShapeLayout result_layout_; diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index b65dfef9c9575b683b2656af2ccc151d87db2cd7..245db6be2a400a7447f1e87317018cbb1572c405 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" @@ -31,8 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -40,10 +40,12 @@ namespace { using absl::StrAppend; -bool IsEntryParameterValue(const HloValue& value) { +bool IsReadonlyEntryParameterValue(const HloValue& value) { const HloComputation* computation = value.defining_instruction()->parent(); return value.defining_instruction()->opcode() == HloOpcode::kParameter && - computation == computation->parent()->entry_computation(); + computation == computation->parent()->entry_computation() && + !computation->parent()->input_output_alias_config().ParameterHasAlias( + value.defining_instruction()->parameter_number(), value.index()); } bool IsConstantValue(const HloValue& value) { @@ -51,7 +53,7 @@ bool IsConstantValue(const HloValue& value) { } bool ValueIsReadOnly(const HloValue& value) { - return IsConstantValue(value) || IsEntryParameterValue(value); + return IsConstantValue(value) || IsReadonlyEntryParameterValue(value); } // Data 
structure describing the action which should be taken on parts of a @@ -79,8 +81,7 @@ SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node, bool ShouldCopyRootValue(const HloValue& value, const SpecialCaseCopyPolicy& policy) { if (policy.copy_parameters_and_constants) { - return IsConstantValue(value) || - value.defining_instruction()->opcode() == HloOpcode::kParameter; + return ValueIsReadOnly(value); } return false; } @@ -332,6 +333,81 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, return Status::OK(); } +// Conservatively adds copies before root instruction of entry computation and +// each aliased parameter to resolve interference of aliased input and output +// buffer. We later rely on the CopyRemover to drop the unnecessary ones. +Status AddCopiesForAliasedInputOutputs(HloModule* module) { + HloComputation* entry = module->entry_computation(); + HloInstruction* root = entry->root_instruction(); + + ShapeTree output_indices_to_copy(root->shape()); + std::vector> copied_parameters; + bool has_alias = false; + for (auto* param : entry->parameter_instructions()) { + bool param_has_alias = false; + ShapeTree param_indices_to_copy(param->shape()); + + module->input_output_alias_config().ForEachAlias( + [&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + if (param_number == param->parameter_number()) { + param_has_alias = true; + *(param_indices_to_copy.mutable_element(param_index)) = true; + *(output_indices_to_copy.mutable_element(output_index)) = true; + } + }); + + if (!param_has_alias) { + continue; + } + + has_alias = true; + // Store a snapshot of users before DeepCopyInstruction, as + // DeepCopyInstruction introduces new users of the instruction. 
+ std::vector users = param->users(); + ShapeTree param_copy_tree(param->shape(), + /*init_value=*/nullptr); + TF_ASSIGN_OR_RETURN(HloInstruction * copied, + entry->DeepCopyInstruction( + param, ¶m_indices_to_copy, ¶m_copy_tree)); + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied)); + } + + copied_parameters.push_back(param_copy_tree); + } + + if (!has_alias) { + return Status::OK(); + } + + // Add copies before root instruction. + ShapeTree output_copy_tree(root->shape(), + /*init_value=*/nullptr); + + TF_ASSIGN_OR_RETURN(HloInstruction * root_copied, + root->parent()->DeepCopyInstruction( + root, &output_indices_to_copy, &output_copy_tree)); + + // Add control dependencies between the input/output copies. + TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus( + [&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& input_index) -> Status { + HloInstruction* from = + copied_parameters[param_number].element(input_index); + HloInstruction* to = output_copy_tree.element(output_index); + + TF_RET_CHECK(from != nullptr); + TF_RET_CHECK(to != nullptr); + TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to)); + return Status::OK(); + })); + + entry->set_root_instruction(root_copied); + + return Status::OK(); +} + // Removes any control dependencies to or from the given instruction. Status StripControlDependenciesFrom(HloInstruction* instruction) { while (!instruction->control_successors().empty()) { @@ -432,7 +508,7 @@ class CopyRemover { // Construct a list for each HLO buffer in the alias analysis. Maintain a // map from HloValue to the respective list element representing that // value. The map is used to construct the copy info map below. - tensorflow::gtl::FlatMap value_to_node; + absl::flat_hash_map value_to_node; for (const HloBuffer& buffer : alias_analysis.buffers()) { // Verify values contained in the buffer are strictly ordered. 
This // should always be the case after adding copies to eliminate @@ -480,7 +556,7 @@ class CopyRemover { // respective ValueNode representing that value. void AddValueList( absl::Span values, - tensorflow::gtl::FlatMap* value_to_node) { + absl::flat_hash_map* value_to_node) { ValueNode* tail = nullptr; ValueNode* head = nullptr; for (const HloValue* value : values) { @@ -516,8 +592,7 @@ class CopyRemover { // respective ValueNode. void CreateCopyMap( const HloModule& module, - const tensorflow::gtl::FlatMap& - value_to_node) { + const absl::flat_hash_map& value_to_node) { for (HloComputation* computation : module.computations()) { for (HloInstruction* instruction : computation->instructions()) { // Add copies with unambiguous source values to the map. Copies with @@ -905,7 +980,7 @@ class CopyRemover { // The heads of all the value lists. Each value list represents the HLO // values contained in a particular HLO buffer. The values in the list are // in dependency order. - tensorflow::gtl::FlatSet value_lists_; + absl::flat_hash_set value_lists_; // Copy removal requires fast access to the value list elements // corresponding to the source and destination values of the kCopy @@ -916,7 +991,7 @@ class CopyRemover { ValueNode* src = nullptr; ValueNode* dest = nullptr; }; - tensorflow::gtl::FlatMap copy_map_; + absl::flat_hash_map copy_map_; }; HloModule* module_; @@ -954,6 +1029,8 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { } } } + + TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module)); return Status::OK(); } @@ -1010,7 +1087,7 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, HloInstruction* root = computation->root_instruction(); // Mark nondistinct/ambiguous indices. 
- tensorflow::gtl::FlatSet seen; + absl::flat_hash_set seen; ShapeUtil::ForEachSubshape( root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { std::vector buffers_at_index = diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 892d0d7b547aaf1e7f1c55e4163d1e1fd9518def..4533ebb99bbba854a029fb8a9a1e31b023be720d 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1351,6 +1351,218 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) { EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); } +TEST_F(CopyInsertionTest, CrossingParameters) { + // Test a case where two parameters' dataflow cross with each other while + // input and output are aliased with same index: + // + // (p0 , p1) + // | \ /| + // | \ / | + // alias X alias + // | / \ | + // | / \| + // (p1 , p0) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 4); +} + +TEST_F(CopyInsertionTest, ParametersAliasing) { + // Test a case where two parameters' dataflow don't 
interfere with each other + // while aliased. + // + // (p0 , p1) + // | | + // | | + // alias alias + // | | + // | | + // (p0 , p1) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 0); +} + +TEST_F(CopyInsertionTest, ParameterWithNoAliasing) { + // Test a case where no parameter is aliased with result. 
In this case, copy + // should be added + // + // (p0 , p1) + // | | + // | | + // | | + // | | + // | | + // (p0 , p1) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module->AddEntryComputation(builder.Build()); + InsertCopies(module.get()); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(op::GetTupleElement(param, 0)), + op::Copy(op::GetTupleElement(param, 1)))); + + EXPECT_EQ(CountCopies(*module), 2); +} + +TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { + // Test a case where one parameter is aliased with result while another one + // isn't. 
+ // + // (p0 , p1) + // | | + // | | + // alias | + // | | + // | | + // (p0 , p1) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + InsertCopies(module.get()); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::GetTupleElement(param, 0), + op::Copy(op::GetTupleElement(param, 1)))); + + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { + // Test a case where one parameter is aliased with result while another one + // isn't. 
+ // + // +-- (p0 , p1) + // | | | + // | | | + // alias Negate Negate + // | | | + // | | | + // +-- (p0 , p1) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + + auto negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); + + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); + builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 0); +} + +TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { + // Test a case where one parameter is aliased with result while another one + // isn't. 
+ // + // +-- (p0 , p1) + // | | | + // | | | + // alias Negate Negate + // | | | + // | Add----+ + // | | | + // +-- (p0 , p1) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + + auto negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); + + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, negate0, negate1)); + builder.AddInstruction(HloInstruction::CreateTuple({add, negate1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 0); +} + TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { // Test a while instruction with a body which permutes its tuple parameter // elements and applies one operation to one of the elements. 
The addition of diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index b7103118ac5cbd47e060b170a8e432e2ec93c0fd..58abb330a6e31e9b7a8081cd7964cf89a5b64a09 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -94,6 +94,7 @@ cc_library( ":target_machine_features", "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", @@ -127,7 +128,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:indexed_array_analysis", - "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", @@ -290,6 +290,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -309,6 +311,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@llvm//:analysis", "@llvm//:target", ], @@ -471,6 +474,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "//tensorflow/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", ], @@ -762,6 +766,7 @@ cc_library( 
"//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 18fc144efe0023c0893adfcb16eda3341c0938d3..da01c0caf2a6665f71cc087270b21fffdd6caa0d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -86,8 +86,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" -#include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/scatter_expander.h" @@ -249,9 +249,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); - // TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding - // where we will take this pass in future. - // pipeline.AddPass(); + pipeline.AddPass(); // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. 
@@ -308,7 +306,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->mutable_entry_computation_layout(), target_machine_features); + module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout, target_machine_features); return pipeline.Run(module).status(); } @@ -328,8 +327,13 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( { auto& pass = pipeline.AddPass>( "simplification after layout assignement"); - pass.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + // TODO(b/117156505): When the bug is fixed, the CPU backend should not + // produce layout changing elementwise operations. We will then pass + // LayoutAssignment::InstructionCanChangeLayout to the HLO verifier to + // enable stricter verification. + pass.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); pass.AddPass>( /*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return true; }, @@ -672,9 +676,12 @@ StatusOr> CpuCompiler::RunBackend( } StatusOr>> -CpuCompiler::CompileAheadOfTime(std::vector> modules, +CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& aot_options) { - TF_RET_CHECK(!modules.empty()); + TF_RET_CHECK(!module_group->empty()); + std::vector> modules = + module_group->ConsumeModules(); + std::call_once(llvm_command_line_options_initialized, &llvm_ir::InitializeLLVMCommandLineOptions, modules[0]->config()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index f2af923782df268e3e6da3895ec35579ab6aa51f..c67307548dda731f8fa56b8e6790e7e83f587113 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -142,7 +142,7 @@ class CpuCompiler : public LLVMCompiler { DeviceMemoryAllocator* device_allocator) override; StatusOr>> - 
CompileAheadOfTime(std::vector> modules, + CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options) override; se::Platform::Id PlatformId() const override; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc index be1208fb2df2a1a11a093810b5f6c2a83f468062..e6b6fcdf684eadb3702e490bbe24dbb7b3b52ec7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -57,10 +57,13 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { HloInstruction::CreateParameter(1, sparse_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( sparse_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); + // Since verifier is reporting sparse layouts as errors, we should + // use a regular HloModule instead of VerifiedHloModule to avoid + // verifier errors being triggered in the destructor. + auto module = HloTestBase::CreateNewModule(); module->AddEntryComputation(builder.Build()); - Status status = checker().Run(module).status(); + Status status = checker().Run(module.get()).status(); ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); EXPECT_THAT(status.error_message(), HasSubstr("CPU backend does not support")); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index bfecbd6e017893e4f6d3dcbc01d46c899e6060fa..c291bf2d1ba2eaff4192051840768c037bece86f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" @@ -38,7 +39,7 @@ using absl::nullopt; using absl::optional; using ShouldMakeOperandColMajorCache = - tensorflow::gtl::FlatMap; + absl::flat_hash_map; } // namespace static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index 3c4fe68b830d9602f009b318d4e51e9a04a27e09..f4da35dd373f24d81323d198582048e2e6d36268 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -30,8 +30,11 @@ class CpuLayoutAssignment : public LayoutAssignment { public: explicit CpuLayoutAssignment( ComputationLayout* entry_computation_layout, + std::function + instruction_can_change_layout_func, const TargetMachineFeatures* target_machine_features) - : LayoutAssignment(entry_computation_layout), + : LayoutAssignment(entry_computation_layout, + std::move(instruction_can_change_layout_func)), target_machine_features_(*target_machine_features) {} ~CpuLayoutAssignment() override {} diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 4668f3872dad598edf4c7680e1b601622104ab3e..97659b88a7974d7caf91ab0d4741f3635e4dae4a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -54,8 +54,9 @@ class CpuLayoutAssignmentTest : public HloTestBase { [](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); - cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout, - &target_machine_features); + 
cpu::CpuLayoutAssignment layout_assignment( + entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout, + &target_machine_features); EXPECT_IS_OK(layout_assignment.Run(module).status()); } }; @@ -321,8 +322,9 @@ static StatusOr RunDotOutputFusion( [](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); - cpu::CpuLayoutAssignment layout_assignment(&computation_layout, - &target_machine_features); + cpu::CpuLayoutAssignment layout_assignment( + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + &target_machine_features); TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something, layout_assignment.Run(module)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 20cf8557354b161451cf5b7825ccfce57d96875a..a9febe891b5e9d1eb9e6b297952b50d1d26a3396 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/core/platform/dynamic_annotations.h" @@ -30,8 +31,7 @@ namespace cpu { namespace runtime { XfeedManager* GetXfeedManager(int device_ordinal) { - static tensorflow::gtl::FlatMap* managers = - new tensorflow::gtl::FlatMap(); + static auto* managers = new absl::flat_hash_map(); static absl::Mutex* mutex = new absl::Mutex(); absl::MutexLock lock(mutex); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index c3e802078385d4724f0da26e8b6c16503e3662a1..b2abdb39a598871a7cc44760e464f48b9a200874 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -24,6 +24,8 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" @@ -67,8 +69,6 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -688,8 +688,25 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); - input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); + input_index[i] = NSWSub( + NSWAdd(strided_index, + NSWMul(window_index[i], + b_.getInt64(window.dimensions(i).window_dilation()))), + b_.getInt64(window.dimensions(i).padding_low())); + + // We need to verify that we are not in the dilated base area. + llvm::Value* dilation_condition = ICmpEQ( + SRem(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())), + b_.getInt64(0)); + if (in_bounds_condition == nullptr) { + in_bounds_condition = dilation_condition; + } else { + in_bounds_condition = And(in_bounds_condition, dilation_condition); + } + + // Apply base dilation to the index. + input_index[i] = + SDiv(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())); // We need to check if 0 <= input_index[i] < bound, as otherwise we are in // the padding so that we can skip the computation. That is equivalent to @@ -728,12 +745,6 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { /*operands=*/{reduce_window->operand(0)}, /*supported_types=*/{F32, BF16, S32, F16})); - // TODO(b/31410564): Implement dilation for reduce-window. 
- if (window_util::HasDilation(reduce_window->window())) { - return Unimplemented( - "Dilation for ReduceWindow is not implemented on CPU."); - } - // Pseudo code for reduce window: // // for (coordinates O in the output) @@ -1398,10 +1409,10 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { // // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains // [0->0, 3->1]. - gtl::FlatMap unreduced_dim_map; + absl::flat_hash_map unreduced_dim_map; - gtl::FlatSet reduced_dims(reduce.dimensions().begin(), - reduce.dimensions().end()); + absl::flat_hash_set reduced_dims(reduce.dimensions().begin(), + reduce.dimensions().end()); const Shape& operand_shape = reduce.operand(0)->shape(); const Shape& result_shape = reduce.shape(); @@ -1977,7 +1988,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // // * Implement the memcpy within the innermost loop. - gtl::FlatSet inner_dims; + absl::flat_hash_set inner_dims; for (int64 dim : LayoutUtil::MinorToMajor(layout)) { if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) { break; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index daafef4eb38f14679e025d8e75dd671e94198102..586f27b104ed706a3b128903c6a90abbf3667e59 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/Triple.h" @@ -47,7 +48,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -427,7 +427,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Maps the buffer allocation slices for the parameters to the computation // being compiled to their parameter numbers. Only relevant for thread local // computations. - tensorflow::gtl::FlatMap + absl::flat_hash_map computation_parameter_allocations_; // Maps HLO instructions to their index into the profile counter array. @@ -567,11 +567,11 @@ class IrEmitter : public DfsHloVisitorWithDefault, } }; - tensorflow::gtl::FlatMap + absl::flat_hash_map emitted_literals_; - tensorflow::gtl::FlatMap + absl::flat_hash_map constant_buffer_to_global_; std::vector thread_local_computations_; diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 9ec0c8f65705db335379649def746921e6b05bea..f77641eb7da71117092730c1fd5090c61c939813 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -108,15 +108,15 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, [](llvm::Error Err) { cantFail(std::move(Err), "lookupFlags failed"); })), - object_layer_(execution_session_, - [this](llvm::orc::VModuleKey) { - llvm::orc::RTDyldObjectLinkingLayer::Resources result; - result.MemMgr = - std::make_shared( - orc_jit_memory_mapper::GetInstance()); - result.Resolver = symbol_resolver_; - return result; - }), + object_layer_( + execution_session_, + [this](llvm::orc::VModuleKey) { + llvm::orc::LegacyRTDyldObjectLinkingLayer::Resources result; + result.MemMgr = std::make_shared( + orc_jit_memory_mapper::GetInstance()); + result.Resolver = symbol_resolver_; + return result; + }), 
compile_layer_(object_layer_, CompilerFunctor(target_machine_.get(), &disassembler_, opt_level, optimize_for_size, diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index d74b63fcf45bd70cd18ee41f1e9714ba6a222abd..78406ba143570183aea09d79db3f9b708c21bf70 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -44,9 +44,9 @@ namespace cpu { // it's added to the JIT. class SimpleOrcJIT { public: - using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer; + using ObjLayerT = llvm::orc::LegacyRTDyldObjectLinkingLayer; using CompileFtor = std::function; - using CompileLayerT = llvm::orc::IRCompileLayer; + using CompileLayerT = llvm::orc::LegacyIRCompileLayer; using VModuleKeyT = llvm::orc::VModuleKey; // Create a new JIT, targeting the host architecture. diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc index a0cd8ee2d2be10bcee9c2e216e24908d949e2d7b..5cdac203af2e7a1f8f3aebda965447ba75e9934e 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/core/platform/logging.h" namespace xla { namespace cpu { diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h index 8b00ae9e47eeed26ffe80707b89593b267e8dbb8..a383b4a4a00f9b8d49a88e8349793a3a90d8da7b 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.h +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h @@ -16,10 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_ +#include "absl/container/flat_hash_map.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace cpu { @@ -97,8 +97,7 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures { // This is mutated from within `GetTargetTransformInfoFor` which is // semantically a getter (and thus `const`); and is therefore declared // mutable. Making this mutable is okay because it has cache semantics. - mutable tensorflow::gtl::FlatMap + mutable absl::flat_hash_map target_transform_info_cache_; llvm::TargetMachine* target_machine_; }; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index 7af51db55af44ae1e437ea8e4de7427012cad82f..b35fd9dad877c319c3d0110c96a00aeefa78769e 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -121,7 +121,7 @@ TEST_F(CpuNoAliasTest, Concat) { CHECK: %read_concat2_array = load {{.*}} !alias.scope [[concat1_noalias]], !noalias [[concat1_scope]] CHECK-DAG: [[buf_size32:![0-9]+]] = !{!"buffer:{{.*}} size:32 CHECK-DAG: [[buf_size48:![0-9]+]] = !{!"buffer:{{.*}} size:48 - CHECK-DAG: [[param_x_noalias]] = !{[[buf_size32]], [[buf_size48]]} + CHECK-DAG: [[param_x_noalias]] = !{[[buf_size48]], [[buf_size32]]} CHECK-DAG: [[concat1_scope]] = !{[[buf_size32]]} CHECK-DAG: [[concat1_noalias]] = !{[[buf_size48]]} )"; diff --git a/tensorflow/compiler/xla/service/defuser.cc b/tensorflow/compiler/xla/service/defuser.cc index d124f74d19d83269be96ee34a6b4b2a8d00a978f..661539cccb4ef27a49a73f97a0a8b0d9dfc77061 100644 --- a/tensorflow/compiler/xla/service/defuser.cc +++ 
b/tensorflow/compiler/xla/service/defuser.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -48,7 +49,7 @@ Status Defuse(HloInstruction* fusion_instruction) { fusion_instruction->fused_instructions_computation(); // A map from fused instruction to its defused clone. - tensorflow::gtl::FlatMap + absl::flat_hash_map defused_instructions; // Initialize map to contain the fusion instruction parameters mapping // to the operands of the fusion instruction. diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc index 3e7373adc5ab8a60fd18348ce2477175aaaa8fd4..c54f81e6915a286757e59821c2684a7271889816 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -50,7 +50,7 @@ void DfsHloVisitorBase::SetVisiting( const HloInstruction& instruction) { VLOG(3) << "marking HLO " << &instruction << " as visiting: "; DCHECK(NotVisited(instruction)); - visit_state_.SetState(instruction.unique_id(), VisitState::kVisiting); + visit_state_[instruction.unique_id()] = VisitState::kVisiting; } template @@ -58,7 +58,7 @@ void DfsHloVisitorBase::SetVisited( const HloInstruction& instruction) { VLOG(3) << "marking HLO " << &instruction << " as visited: "; DCHECK(NotVisited(instruction) || IsVisiting(instruction)); - visit_state_.SetState(instruction.unique_id(), VisitState::kVisited); + visit_state_[instruction.unique_id()] = VisitState::kVisited; } template diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 5761573791d90e45c65b55124a4bae3c5b929ef1..4159aa281fa2b66d310d7c135f123a5a3bb83270 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ 
b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -264,21 +264,25 @@ class DfsHloVisitorBase { kVisited = 2, }; - VisitState GetVisitState(int id) { return visit_state_.GetState(id); } + VisitState GetVisitState(int id) { + auto iter = visit_state_.find(id); + if (iter == visit_state_.end()) { + return VisitState::kNotVisited; + } + return iter->second; + } VisitState GetVisitState(const HloInstruction& instruction); // Resize internal state if necessary to hold state for ids <= num. // This call is purely a performance hint and can be omitted without // affecting correctness. - void ReserveVisitStates(int num) { visit_state_.Reserve(num); } + void ReserveVisitStates(int num) { visit_state_.reserve(num); } // Useful when we want to visit the same computation more than once with the // same visitor. - void ResetVisitStates() { visit_state_.Reset(); } + void ResetVisitStates() { visit_state_.clear(); } - void SetVisitState(int id, VisitState state) { - visit_state_.SetState(id, state); - } + void SetVisitState(int id, VisitState state) { visit_state_[id] = state; } // Sets the visitation state of the given instruction as kVisiting. 
// @@ -327,44 +331,7 @@ class DfsHloVisitorBase { virtual Status Postprocess(HloInstructionPtr hlo); private: - class DFSVisitStates { - public: - DFSVisitStates() {} - void Reserve(uint64 num) { - states_.reserve((num + kStatesPerWord - 1) / kStatesPerWord); - } - VisitState GetState(uint64 id) { - uint64 word_index = id / kStatesPerWord; - if (word_index >= states_.size()) { - return VisitState::kNotVisited; - } - static_assert(static_cast(VisitState::kVisited) < 3, - "VisitState must fit in two bits"); - uint64 w = states_[word_index]; - uint32 shift = 2 * (id % kStatesPerWord); // 2 bits per state - return static_cast((w >> shift) & 0x3); - } - void SetState(uint64 id, VisitState state) { - uint64 word_index = id / kStatesPerWord; - if (word_index >= states_.size()) { - states_.resize(word_index + 1, 0); - } - uint64* w = &states_[word_index]; - uint32 shift = 2 * (id % kStatesPerWord); // 2 bits per state - uint64 mask = 0x3ull << shift; - *w = (*w & ~mask) | (static_cast(state) << shift); - DCHECK_EQ(GetState(id), state); - } - void Reset() { states_.clear(); } - - private: - static const uint32 kStatesPerWord = sizeof(uint64) / 2 /*bits per entry*/; - // Map from id to two-bit states. We store 32 such states per 64-bit - // value - std::vector states_; - }; - - DFSVisitStates visit_state_; + absl::flat_hash_map visit_state_; TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorBase); }; diff --git a/tensorflow/compiler/xla/service/fusion_queue.h b/tensorflow/compiler/xla/service/fusion_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..1208a7dda87d7b2a6ad7113e2604e8b9a0fa045b --- /dev/null +++ b/tensorflow/compiler/xla/service/fusion_queue.h @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { + +// A queue interface that allows implementations to choose fusion candidates in +// custom order. +class FusionQueue { + public: + FusionQueue() = default; + virtual ~FusionQueue() = default; + + // Dequeues the next fusion candidates: a consumer and the list of producers + // as operand indices. + virtual std::pair> + DequeueNextInstructionAndOperandsToFuseInOrder() = 0; + + // A callback passed to the queue implementation right before the producer is + // fused into the consumer. + virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {} + + // A callback passed to the queue implementation right after the fusion is + // created. Note that original_producer could have been destroyed. + virtual void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) {} + + // A callback passed to the queue implementation to notify the removal of an + // instruction. 
+ virtual void RemoveInstruction(HloInstruction* instruction) = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index 2b39359aae9fc01f1a88a2594108b2772788e826..8af9c6b71fbc391bf7c0e9809e979b65135a6df3 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -28,7 +28,7 @@ class GatherExpander : public HloModulePass { absl::string_view name() const override { return "gather_expander"; } StatusOr Run(HloModule* module) override; - private: + protected: StatusOr ExpandGather(HloInstruction* gather_instr); }; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 51968d13d492d6cb1d9731c9c18c7c8e4962c0d5..449fd919d64612ba10932c7cc0865f1fca96424a 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -25,6 +25,10 @@ filegroup( ) load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) xla_proto_library( name = "backend_configs", @@ -91,6 +95,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", ], ) @@ -153,7 +158,7 @@ cc_library( deps = [ ":backend_configs", ":buffer_allocations", - ":cudnn_convolution_runner", + ":cudnn_conv_runner", ":elemental_ir_emitter", ":gpu_constants", ":gpu_executable", @@ -322,7 +327,7 @@ cc_library( ], deps = [ ":buffer_allocations", - ":cudnn_convolution_runner", + ":cudnn_conv_runner", ":hlo_execution_profiler", ":infeed_manager", ":ir_emission_utils", @@ -357,6 +362,7 @@ cc_library( "//tensorflow/core/platform/default/build_config:cufft_plugin", 
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -383,13 +389,13 @@ cc_library( ) cc_library( - name = "cudnn_convolution_algorithm_picker", - srcs = ["cudnn_convolution_algorithm_picker.cc"], - hdrs = ["cudnn_convolution_algorithm_picker.h"], + name = "cudnn_conv_algorithm_picker", + srcs = ["cudnn_conv_algorithm_picker.cc"], + hdrs = ["cudnn_conv_algorithm_picker.h"], deps = [ ":backend_configs", ":buffer_comparator", - ":cudnn_convolution_runner", + ":cudnn_conv_runner", ":gpu_executable", ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", @@ -402,14 +408,15 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", ], ) cc_library( - name = "cudnn_convolution_runner", - srcs = ["cudnn_convolution_runner.cc"], - hdrs = ["cudnn_convolution_runner.h"], + name = "cudnn_conv_runner", + srcs = ["cudnn_conv_runner.cc"], + hdrs = ["cudnn_conv_runner.h"], deps = [ ":backend_configs", ":ir_emission_utils", @@ -429,9 +436,9 @@ cc_library( ) cc_library( - name = "cudnn_convolution_rewriter", - srcs = ["cudnn_convolution_rewriter.cc"], - hdrs = ["cudnn_convolution_rewriter.h"], + name = "cudnn_conv_rewriter", + srcs = ["cudnn_conv_rewriter.cc"], + hdrs = ["cudnn_conv_rewriter.h"], deps = [ ":backend_configs", ":ir_emission_utils", @@ -446,10 +453,10 @@ cc_library( ) tf_cc_test( - name = "cudnn_convolution_rewriter_test", - srcs = ["cudnn_convolution_rewriter_test.cc"], + name = "cudnn_conv_rewriter_test", + srcs = ["cudnn_conv_rewriter_test.cc"], deps = [ - ":cudnn_convolution_rewriter", + ":cudnn_conv_rewriter", ":ir_emission_utils", "//tensorflow/compiler/xla:test", 
"//tensorflow/compiler/xla:test_helpers", @@ -474,6 +481,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", "//tensorflow/compiler/xla/service:pattern_matcher", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -506,6 +514,7 @@ cc_library( "//tensorflow/compiler/xla/service:multi_output_fusion", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -539,6 +548,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -575,9 +585,9 @@ tf_cc_test( ) cc_library( - name = "pad_insertion", - srcs = ["pad_insertion.cc"], - hdrs = ["pad_insertion.h"], + name = "cudnn_conv_padding_legalization", + srcs = ["cudnn_conv_padding_legalization.cc"], + hdrs = ["cudnn_conv_padding_legalization.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:literal", @@ -585,6 +595,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:shape_inference", @@ -593,9 +604,9 @@ cc_library( ) cc_library( - name = "pad_for_tensor_cores", - srcs = ["pad_for_tensor_cores.cc"], - hdrs = ["pad_for_tensor_cores.h"], + name = "cudnn_conv_pad_for_tensor_cores", + srcs = ["cudnn_conv_pad_for_tensor_cores.cc"], + hdrs = ["cudnn_conv_pad_for_tensor_cores.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", @@ -607,11 +618,11 @@ cc_library( ) tf_cc_test( - name = "pad_for_tensor_cores_test", - srcs = ["pad_for_tensor_cores_test.cc"], + name = "cudnn_conv_pad_for_tensor_cores_test", + srcs = 
["cudnn_conv_pad_for_tensor_cores_test.cc"], deps = [ + ":cudnn_conv_pad_for_tensor_cores", ":ir_emission_utils", - ":pad_for_tensor_cores", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_matchers", @@ -653,9 +664,11 @@ cc_library( srcs = ["nvptx_compiler.cc"], hdrs = ["nvptx_compiler.h"], deps = [ - ":cudnn_convolution_algorithm_picker", - ":cudnn_convolution_rewriter", - ":cudnn_fused_convolution_rewriter", + ":cudnn_conv_algorithm_picker", + ":cudnn_conv_pad_for_tensor_cores", + ":cudnn_conv_padding_legalization", + ":cudnn_conv_rewriter", + ":cudnn_fused_conv_rewriter", ":fusion_merger", ":gpu_constants", ":gpu_copy_insertion", @@ -667,8 +680,6 @@ cc_library( ":ir_emission_utils", ":ir_emitter", ":multi_output_fusion", - ":pad_for_tensor_cores", - ":pad_insertion", ":partition_assignment", ":stream_assignment", ":stream_executor_util", @@ -699,7 +710,6 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", - "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", @@ -713,6 +723,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -774,7 +785,6 @@ cc_library( srcs = ["gpu_layout_assignment.cc"], hdrs = ["gpu_layout_assignment.h"], deps = [ - ":gpu_options", ":ir_emission_utils", ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", @@ -875,16 +885,6 @@ cc_library( ], ) -cc_library( - name = "gpu_options", - srcs = ["gpu_options.cc"], - hdrs = 
["gpu_options.h"], - deps = [ - "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/core:lib_internal", - ], -) - cc_library( name = "stream_executor_util", srcs = ["stream_executor_util.cc"], @@ -970,9 +970,9 @@ tf_cc_test( ) cc_library( - name = "cudnn_fused_convolution_rewriter", - srcs = ["cudnn_fused_convolution_rewriter.cc"], - hdrs = ["cudnn_fused_convolution_rewriter.h"], + name = "cudnn_fused_conv_rewriter", + srcs = ["cudnn_fused_conv_rewriter.cc"], + hdrs = ["cudnn_fused_conv_rewriter.h"], deps = [ ":backend_configs", ":ir_emission_utils", @@ -984,3 +984,18 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", ], ) + +tf_cc_test( + name = "cudnn_fused_conv_rewriter_test", + srcs = ["cudnn_fused_conv_rewriter_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", + "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 4effea637d01bf23b54d341b77306b20b1b133c8..e1dffad3045808c4f316ccafdda39a174e1560c8 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" @@ -56,9 +56,9 @@ Status ConvolutionThunk::ExecuteOnStream( buffer_allocations.GetDeviceAddress(scratch_buffer_); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); - TF_RETURN_IF_ERROR(RunCudnnConvolution(cudnn_call_, - absl::MakeSpan(operand_se_buffers), - result_buffer, scratch, stream)); + TF_RETURN_IF_ERROR(RunCudnnConv(cudnn_call_, + absl::MakeSpan(operand_se_buffers), + result_buffer, scratch, stream)); void* ptrs[] = {result_buffer.opaque(), scratch.opaque()}; se::DeviceMemory tuple_addr( diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index f53bc541983378819dba36489dd69c348f50af32..c71515490c94ef54baad9005509d1813de630159 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -19,7 +19,7 @@ limitations under the License. 
#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc similarity index 90% rename from tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc index 7125673887d28729287d67577bcfa06423f85611..6d6780fa1c7b0c636eb771c40e74f074cd8c4c4b 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/optional.h" @@ -145,9 +145,8 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // cache misses and doing extra work. Overall, caching doesn't seem worth the // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. 
-StatusOr> -CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( - HloCustomCallInstruction* instr) { +StatusOr +CudnnConvAlgorithmPicker::PickBestAlgorithm(HloCustomCallInstruction* instr) { // TODO(timshen): for now only check fp16. It can be expanded to other types, // with some work on the HLO routines. const bool cross_check_enabled = @@ -253,10 +252,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( backend_config.set_algorithm(alg.algo_id()); backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled()); TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config)); - bool launch_ok = RunCudnnConvolution(instr, absl::MakeSpan(operand_buffers), - result_buffer, &scratch_allocator, - &stream, &profile_result) - .ok(); + bool launch_ok = + RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer, + &scratch_allocator, &stream, &profile_result) + .ok(); if (launch_ok && profile_result.is_valid()) { const bool crash_on_checking_failure = @@ -316,9 +315,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( << AlgorithmToString(best_result.algorithm()) << ", takes " << best_result.elapsed_time_in_ms() << "ms, and uses " << best_result_bytes_used << "B of scratch memory."; - return std::make_tuple(best_result.algorithm().algo_id(), - best_result.algorithm().tensor_ops_enabled(), - best_result_bytes_used); + return AutotuneResult{best_result.algorithm().algo_id(), + best_result.algorithm().tensor_ops_enabled(), + best_result_bytes_used, + absl::Milliseconds(best_result.elapsed_time_in_ms())}; } return InternalError( @@ -327,45 +327,41 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( instr->ToString()); } -StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( +StatusOr CudnnConvAlgorithmPicker::RunOnInstruction( HloInstruction* instr) { CHECK(IsCustomCallToDnnConvolution(*instr)); - StatusOr> alg_scratch_and_tc = + StatusOr best_algo_or = PickBestAlgorithm(Cast(instr)); - - if (!alg_scratch_and_tc.ok()) { - LOG(ERROR) << 
alg_scratch_and_tc.status(); + if (!best_algo_or.ok()) { + LOG(ERROR) << best_algo_or.status(); return false; } - int64 algorithm; - bool tensor_ops_enabled; - int64 scratch_bytes; - - std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = - alg_scratch_and_tc.ConsumeValueOrDie(); - - VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " - << NumBytesToString(scratch_bytes) + auto best_algo = std::move(best_algo_or).ValueOrDie(); + VLOG(1) << "Setting cudnn conv to use algorithm " << best_algo.algorithm + << " and " << NumBytesToString(best_algo.scratch_bytes) << " of scratch memory: " << instr->ToString() - << " tensor_ops_enabled: " << tensor_ops_enabled; + << " tensor_ops_enabled: " << best_algo.tensor_ops_enabled; // Replace instr with a new CustomCall which has the correct algorithm, and // whose output shape has the appropriate amount of scratch memory. HloComputation* computation = instr->parent(); - Shape new_call_shape = - ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), - ShapeUtil::MakeShape(U8, {scratch_bytes})}); + Shape new_call_shape = ShapeUtil::MakeTupleShape( + {instr->shape().tuple_shapes(0), + ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes})}); TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, instr->backend_config()); - backend_config.set_algorithm(algorithm); - backend_config.set_tensor_ops_enabled(tensor_ops_enabled); + backend_config.set_algorithm(best_algo.algorithm); + backend_config.set_tensor_ops_enabled(best_algo.tensor_ops_enabled); HloInstruction* new_call = computation->AddInstruction( instr->CloneWithNewOperands(new_call_shape, instr->operands())); + VLOG(1) << "Replacing convolution " << instr->ToString() << " with " + << new_call->ToString(); + TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); // Repackage new_call so it has the same shape as the original call, namely @@ -381,7 +377,7 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( return true; } -StatusOr 
CudnnConvolutionAlgorithmPicker::RunOnComputation( +StatusOr CudnnConvAlgorithmPicker::RunOnComputation( HloComputation* computation) { std::vector convs; for (auto* instr : computation->instructions()) { @@ -398,7 +394,7 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnComputation( return changed; } -StatusOr CudnnConvolutionAlgorithmPicker::Run(HloModule* module) { +StatusOr CudnnConvAlgorithmPicker::Run(HloModule* module) { bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations()) { TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h similarity index 71% rename from tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h index aeda2fc7f8b4d6169fc2baa8975119ba7bf68dd2..642af787afc71586d722ecc7e529ed8b3fa64d33 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_ +#include "absl/time/time.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -30,27 +31,32 @@ namespace gpu { // Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for // each and adding explicit scratch space to the CustomCalls. -class CudnnConvolutionAlgorithmPicker : public HloModulePass { +class CudnnConvAlgorithmPicker : public HloModulePass { public: // If the `allocator` parameter is not null, we will use it to allocate temp // memory while timing the various convolution algorithms. If it's null, // we'll use the default allocator on the StreamExecutor. 
- CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* allocator, - Compiler* compiler) + CudnnConvAlgorithmPicker(se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* allocator, Compiler* compiler) : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {} absl::string_view name() const override { - return "cudnn-convolution-algorithm-picker"; + return "cudnn-conv-algorithm-picker"; } StatusOr Run(HloModule* module) override; private: + struct AutotuneResult { + int64 algorithm; + bool tensor_ops_enabled; + int64 scratch_bytes; + absl::Duration runtime; + }; + StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - StatusOr> PickBestAlgorithm( - HloCustomCallInstruction* instr); + StatusOr PickBestAlgorithm(HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null @@ -60,4 +66,4 @@ class CudnnConvolutionAlgorithmPicker : public HloModulePass { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc similarity index 51% rename from tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc index e3869b5c368957571219a39600214140022a7318..5aa4f839f4be5f1060480fea98775f8ffada0bdd 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -24,50 +24,17 @@ limitations under the License. namespace xla { namespace gpu { -// We want the input/output feature counts of an f16 conv to be factors of 8, -// because without this cudnn can't use tensor cores on the conv. -static constexpr int64 kDesiredNumFeaturesFactor = 8; - // We won't pad a conv if doing so increases the total number of bytes in the // lhs, rhs, or result by more than this amount. // // TODO(jlebar): This number was tuned experimentally. It represents a // compromise on our current benchmarks; it speeds some up significantly, and // doesn't slow any down. But we can observe by changing this value that -// there's additional room for speedups. Achieving those speedups without also -// slowing other things down will likely require a more sophisticated heuristic, -// possibly some form of auto-tuning. -// -// This value should be >= 4/3, otherwise the "dims of size 3 padded up to 4" -// special case inside PadShape won't fire. +// there's additional room for speedups. Achieving those speedups without +// also slowing other things down will likely require a more sophisticated +// heuristic, possibly some form of auto-tuning. static constexpr double kMaxBytesTouchedIncrease = 1.35; -// Pads the given dimensions in the given shape up to a multiple of -// kDesiredNumFeaturesFactor. -static Shape PadShape(Shape s, absl::Span dims) { - for (int64 dim : dims) { - int64 dim_to_pad_size = s.dimensions(dim); - - // Round dim_to_pad_size up to the next multiple of - // kDesiredNumFeaturesFactor. - // - // Special case: dims of size 3 are rounded up to 4, not - // kDesiredNumFeaturesFactor. 
Empirically (and on the advice of nvidia), - // this helps, but as of writing, it's not supported by anything in the - // cudnn docs. - int64 new_dim_to_pad_size; - if (dim_to_pad_size == 3) { - new_dim_to_pad_size = 4; - } else { - new_dim_to_pad_size = - RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor); - } - - s.set_dimensions(dim, new_dim_to_pad_size); - } - return s; -} - // Creates and returns an HLO that zero-pads one or more dimensions in the given // instruction so that its shape is equal to the given shape. // @@ -103,90 +70,19 @@ static HloInstruction* PadInstruction(HloInstruction* instr, HloInstruction::CreatePad(new_shape, instr, zero, pad_config)); } -// Pads the input/output feature dimensions of the given cudnn convolution -// custom-call to be multiples of kDesiredNumFeaturesFactor. -static StatusOr PadFeaturesDims(HloInstruction* conv) { +// Modifies the given convolution to have the given LHS/RHS/result shapes. +static Status PadConv(HloCustomCallInstruction* conv, + const Shape& new_lhs_shape, const Shape& new_rhs_shape, + const Shape& new_result_shape) { CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0)) << "conv must use 0 scratch bytes, i.e. this pass must be run " - "before CudnnConvolutionAlgorithmPicker."; + "before CudnnConvAlgorithmPicker."; - const auto& target = conv->custom_call_target(); - const auto& dnums = conv->convolution_dimension_numbers(); auto* lhs = conv->mutable_operand(0); auto* rhs = conv->mutable_operand(1); - const Shape& result_shape = conv->shape().tuple_shapes(0); - - Shape new_lhs_shape = [&] { - if (target == kCudnnConvForwardCallTarget || - target == kCudnnConvBackwardFilterCallTarget) { - // LHS is "input". - return PadShape(lhs->shape(), {dnums.input_feature_dimension()}); - } - CHECK_EQ(target, kCudnnConvBackwardInputCallTarget); - // LHS is "output". 
- return PadShape(lhs->shape(), {dnums.output_feature_dimension()}); - }(); - - Shape new_rhs_shape = [&] { - if (target == kCudnnConvForwardCallTarget || - target == kCudnnConvBackwardInputCallTarget) { - // RHS is "filter". - return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}); - } - CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); - // RHS is "output". - return PadShape(rhs->shape(), {dnums.output_feature_dimension()}); - }(); - - if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) && - ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) { - VLOG(3) << "No need to pad features of " << conv->ToString(); - return false; - } - - Shape new_result_shape = [&] { - if (target == kCudnnConvForwardCallTarget) { - // Result is "output". - return PadShape(result_shape, {dnums.output_feature_dimension()}); - } - if (target == kCudnnConvBackwardInputCallTarget) { - // Result is "input". - return PadShape(result_shape, {dnums.input_feature_dimension()}); - } - CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); - // Result is "filter". - return PadShape(result_shape, {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}); - }(); - - // Check that padding wouldn't increase the total bytes read/written by this - // operation too much. 
- auto check_size_increase = [&](const Shape& old_shape, - const Shape& new_shape) { - int64 old_bytes = ShapeUtil::ByteSizeOf(old_shape); - int64 new_bytes = ShapeUtil::ByteSizeOf(new_shape); - if (new_bytes <= old_bytes * kMaxBytesTouchedIncrease) { - return true; - } - VLOG(3) << "Not padding convolution; doing so would change input / result " - "shape from " - << ShapeUtil::HumanString(old_shape) << " to " - << ShapeUtil::HumanString(new_shape) << ", a size increase of " - << new_bytes / static_cast(old_bytes) << "x > " - << kMaxBytesTouchedIncrease << "x: " << conv->ToString(); - return false; - }; - if (!check_size_increase(lhs->shape(), new_lhs_shape) || - !check_size_increase(rhs->shape(), new_rhs_shape) || - !check_size_increase(result_shape, new_result_shape)) { - return false; - } - - // OK, let's do the transformation! - auto* new_lhs = PadInstruction(lhs, new_lhs_shape); auto* new_rhs = PadInstruction(rhs, new_rhs_shape); + const Shape& result_shape = conv->shape().tuple_shapes(0); CHECK(new_lhs != lhs || new_rhs != rhs) << "We should have had to pad either LHS or RHS."; @@ -219,30 +115,124 @@ static StatusOr PadFeaturesDims(HloInstruction* conv) { VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with " << new_conv->ToString(); - TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(conv, new_conv)); + return conv->parent()->ReplaceInstruction(conv, new_conv); +} + +static StatusOr PadForTensorCores(HloCustomCallInstruction* conv) { + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv)); + const auto& dnums = conv->convolution_dimension_numbers(); + auto* lhs = conv->mutable_operand(0); + auto* rhs = conv->mutable_operand(1); + const Shape& result_shape = conv->shape().tuple_shapes(0); + + // Nothing to do on non-f16 convolutions. + if (result_shape.element_type() != PrimitiveType::F16) { + return false; + } + + // TODO(timshen): Don't skip forward-activation convs if we find a benchmark + // where there's a speedup. 
+ if (kind == CudnnConvKind::kForwardActivation) { + return false; + } + + Shape new_lhs_shape = lhs->shape(); + Shape new_rhs_shape = rhs->shape(); + Shape new_result_shape = conv->shape().tuple_shapes(0); + + // new_{input,filter_output}_shape points to the appropriate one of + // new_{lhs,rhs,result}_shape. + Shape* new_input_shape; + Shape* new_filter_shape; + Shape* new_output_shape; + std::tie(new_input_shape, new_filter_shape, new_output_shape) = [&] { + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + return std::make_tuple(&new_lhs_shape, &new_rhs_shape, + &new_result_shape); + case CudnnConvKind::kBackwardInput: + return std::make_tuple(&new_result_shape, &new_rhs_shape, + &new_lhs_shape); + case CudnnConvKind::kBackwardFilter: + return std::make_tuple(&new_lhs_shape, &new_result_shape, + &new_rhs_shape); + } + }(); + + // If there are 3 input features and 32 or 64 output features, pad the input + // features to 4. Otherwise, try padding to multiples of 8 and check that + // this doesn't make any of the conv buffers too much larger. 
+ auto input_features = + new_input_shape->dimensions(dnums.input_feature_dimension()); + auto output_features = + new_output_shape->dimensions(dnums.output_feature_dimension()); + if (input_features == 3 && (output_features == 32 || output_features == 64)) { + new_input_shape->set_dimensions(dnums.input_feature_dimension(), 4); + new_filter_shape->set_dimensions(dnums.kernel_input_feature_dimension(), 4); + } else { + auto pad_dim = [](Shape* s, int64 dim) { + s->set_dimensions(dim, RoundUpToNearest(s->dimensions(dim), 8)); + }; + pad_dim(new_input_shape, dnums.input_feature_dimension()); + pad_dim(new_filter_shape, dnums.kernel_input_feature_dimension()); + pad_dim(new_filter_shape, dnums.kernel_output_feature_dimension()); + pad_dim(new_output_shape, dnums.output_feature_dimension()); + + // Check that padding wouldn't increase the total bytes read/written by this + // operation too much. + auto check_size_increase = [&](const Shape& old_shape, + const Shape& new_shape) { + int64 old_bytes = ShapeUtil::ByteSizeOf(old_shape); + int64 new_bytes = ShapeUtil::ByteSizeOf(new_shape); + if (new_bytes <= old_bytes * kMaxBytesTouchedIncrease) { + return true; + } + VLOG(3) + << "Not padding convolution; doing so would change input / result " + "shape from " + << ShapeUtil::HumanString(old_shape) << " to " + << ShapeUtil::HumanString(new_shape) << ", a size increase of " + << new_bytes / static_cast(old_bytes) << "x > " + << kMaxBytesTouchedIncrease << "x: " << conv->ToString(); + return false; + }; + + if (!check_size_increase(lhs->shape(), new_lhs_shape) || + !check_size_increase(rhs->shape(), new_rhs_shape) || + !check_size_increase(result_shape, new_result_shape)) { + return false; + } + } + + if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) && + ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) { + VLOG(3) << "No need to pad features of " << conv->ToString(); + return false; + } + + // OK, let's do the transformation! 
+ TF_RETURN_IF_ERROR( + PadConv(conv, new_lhs_shape, new_rhs_shape, new_result_shape)); return true; } -static std::vector GetRelevantConvs(HloComputation* comp) { - std::vector convs; +static std::vector GetRelevantConvs( + HloComputation* comp) { + std::vector convs; for (HloInstruction* instr : comp->instructions()) { - if (IsCustomCallToDnnConvolution(*instr) && - instr->operand(0)->shape().element_type() == F16 && - // TODO(timshen): Disable for fused conv for now. Implement it if it's - // needed. - Cast(instr)->custom_call_target() != - kCudnnConvBiasActivationForwardCallTarget) { - convs.push_back(instr); + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(Cast(instr)); } } return convs; } -StatusOr PadForTensorCores::Run(HloModule* module) { +StatusOr CudnnConvPadForTensorCores::Run(HloModule* module) { bool changed = false; for (HloComputation* comp : module->MakeNonfusionComputations()) { - for (HloInstruction* conv : GetRelevantConvs(comp)) { - TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv)); + for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) { + TF_ASSIGN_OR_RETURN(bool result, PadForTensorCores(conv)); changed |= result; } } diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h similarity index 51% rename from tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h index e592a3774ec28605fda912298c74ca7976ff99ac..d4e51e86c1bf2c1f9aef2eed642604092033a538 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h @@ -13,26 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_TENSOR_CORES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_TENSOR_CORES_H_ #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { namespace gpu { -// Ensures that f16 cudnn convolutions have input/output channel dimensions that -// are multiples of 8, inserting pads/slices as necessary. +// Adds padding to cudnn convolutions to make them run faster on GPUs with +// tensor cores. // -// This is useful primarily for Volta and newer GPUs, where tensor cores can -// only be used if the channel dims are multiples of 8. It's probably the -// opposite of useful on other GPUs, so you should check what GPU you're -// targeting before running this pass. +// - f16 convolutions are padded to have input/output channel dimensions that +// are multiples of 8, so that we can use tensor cores. +// +// - f16 convolutions with 3 input channels and 32 or 64 output channels are +// padded to 4 input channels. There's a special-cased cudnn algorithm just +// for this. +// +// Don't run this pass on GPUs without tensor cores -- it will make them slower! // // TODO(jlebar): Also pad dots. 
-class PadForTensorCores : public HloModulePass { +class CudnnConvPadForTensorCores : public HloModulePass { public: - absl::string_view name() const override { return "pad for tensor cores"; } + absl::string_view name() const override { return "cudnn-conv-pad-for-speed"; } StatusOr Run(HloModule* module) override; }; @@ -40,4 +44,4 @@ class PadForTensorCores : public HloModulePass { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_TENSOR_CORES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc similarity index 63% rename from tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc index 5c92b0dcb873b873074704dca8f27d4067b070df..fa3afa6a5d318c399dc38e8934199b5a1393669e 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -29,10 +29,10 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -class PadForTensorCoresTest : public HloVerifiedTestBase {}; +class CudnnConvPadForTensorCoresTest : public HloVerifiedTestBase {}; -TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { - ParseAndVerifyModule(R"( +TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvInputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule ENTRY TestComputation { @@ -41,11 +41,12 @@ TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward" - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); - SCOPED_TRACE(module().ToString()); + SCOPED_TRACE(module->ToString()); EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget, op::Pad(op::Parameter(0), _), op::Pad(op::Parameter(1), _))); @@ -55,8 +56,8 @@ TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { ShapeUtil::MakeShape(F16, {2, 2, 48, 40}))); } -TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { - ParseAndVerifyModule(R"( +TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule ENTRY TestComputation { @@ -65,9 
+66,10 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBackwardInput" - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::CustomCall(kCudnnConvBackwardInputCallTarget, op::Pad(op::Parameter(0), _), op::Pad(op::Parameter(1), _))); @@ -77,8 +79,8 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { ShapeUtil::MakeShape(F16, {2, 2, 40, 48}))); } -TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) { - ParseAndVerifyModule(R"( +TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvOutputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule ENTRY TestComputation { @@ -87,17 +89,18 @@ TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) { ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward" - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Tuple(op::Slice(op::GetTupleElement(op::CustomCall( kCudnnConvForwardCallTarget, op::Parameter(0), op::Pad(op::Parameter(1), _)))), _)); } -TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { - ParseAndVerifyModule(R"( +TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { + auto module = 
ParseAndReturnVerifiedModule(R"( HloModule TestModule ENTRY TestComputation { @@ -107,9 +110,10 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBackwardInput" ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0 - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::GetTupleElement(op::Tuple( op::Slice(op::GetTupleElement(op::CustomCall( kCudnnConvBackwardInputCallTarget, op::Parameter(0), @@ -117,8 +121,8 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { _))); } -TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { - ParseAndVerifyModule(R"( +TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule ENTRY TestComputation { @@ -128,9 +132,10 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBackwardFilter" ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0 - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::GetTupleElement(op::Tuple( op::Slice(op::GetTupleElement(op::CustomCall( kCudnnConvBackwardFilterCallTarget, @@ -138,8 +143,8 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { _))); } -TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { 
- ParseAndVerifyModule(R"( +TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule ENTRY TestComputation { @@ -149,9 +154,10 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBackwardFilter" ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0 - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::GetTupleElement(op::Tuple( op::Slice(op::GetTupleElement(op::CustomCall( kCudnnConvBackwardFilterCallTarget, @@ -159,6 +165,31 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { _))); } +TEST_F(CudnnConvPadForTensorCoresTest, PadInputFeatures3To4) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,3] parameter(0) + filter = f16[2,2,3,32] parameter(1) + ROOT result = (f16[10,20,30,32], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); + + SCOPED_TRACE(module->ToString()); + EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget, + op::Pad(op::Parameter(0), _), + op::Pad(op::Parameter(1), _))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), + ShapeUtil::MakeShape(F16, {10, 20, 30, 4}))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), + ShapeUtil::MakeShape(F16, {2, 2, 4, 32}))); +} + } // anonymous namespace } // namespace gpu } // namespace xla 
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc similarity index 93% rename from tensorflow/compiler/xla/service/gpu/pad_insertion.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc index b42a19e3a2200e917f8040be183b8d79c9e4e161..d7829045cc127deaa4c2c9b705dca5285d704af2 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -131,7 +132,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, } } // namespace -bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { +bool CudnnConvPaddingLegalization::CanonicalizeForwardConvolution( + HloInstruction* conv) { if (IsForwardConvolutionCanonical(*conv)) { return false; } @@ -186,7 +188,7 @@ void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) { } } // namespace -bool PadInsertion::CanonicalizeBackwardFilterConvolution( +bool CudnnConvPaddingLegalization::CanonicalizeBackwardFilterConvolution( HloInstruction* backward_conv) { CHECK_EQ(backward_conv->custom_call_target(), kCudnnConvBackwardFilterCallTarget); @@ 
-259,7 +261,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( return true; } -bool PadInsertion::CanonicalizeBackwardInputConvolution( +bool CudnnConvPaddingLegalization::CanonicalizeBackwardInputConvolution( HloInstruction* backward_conv) { if (window_util::HasSymmetricPadding(backward_conv->window())) { return false; @@ -376,32 +378,33 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( return true; } -StatusOr PadInsertion::RunOnComputation(HloComputation* computation) { +StatusOr CudnnConvPaddingLegalization::RunOnComputation( + HloComputation* computation) { bool changed = false; - std::vector convs; + std::vector convs; for (auto* instr : computation->instructions()) { if (IsCustomCallToDnnConvolution(*instr)) { - convs.push_back(instr); + convs.push_back(Cast(instr)); } } - for (HloInstruction* instruction : convs) { - const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget || - target == kCudnnConvBiasActivationForwardCallTarget) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - changed |= CanonicalizeBackwardFilterConvolution(instruction); - } else if (target == kCudnnConvBackwardInputCallTarget) { - changed |= CanonicalizeBackwardInputConvolution(instruction); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instruction->ToString(); - } + for (HloCustomCallInstruction* instruction : convs) { + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction)); + changed |= [&] { + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + return CanonicalizeForwardConvolution(instruction); + case CudnnConvKind::kBackwardInput: + return CanonicalizeBackwardInputConvolution(instruction); + case CudnnConvKind::kBackwardFilter: + return CanonicalizeBackwardFilterConvolution(instruction); + } + }(); } return changed; } -StatusOr PadInsertion::Run(HloModule* module) { 
+StatusOr CudnnConvPaddingLegalization::Run(HloModule* module) { bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations()) { TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h similarity index 78% rename from tensorflow/compiler/xla/service/gpu/pad_insertion.h rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h index 25cdf64c4cf01300869044d3e4d7c34c85626a5a..7d1b075517fb285222506e0420984906579e681f 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PADDING_LEGALIZATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PADDING_LEGALIZATION_H_ #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -24,9 +24,11 @@ namespace gpu { // An HLO pass that canonicalizes convolution instructions for GPU codegen. It // inserts Pad instructions before Convolution instructions with uncanonicalized // padding, so that they can be lowered to cuDNN convolution. 
-class PadInsertion : public HloModulePass { +class CudnnConvPaddingLegalization : public HloModulePass { public: - absl::string_view name() const override { return "pad insertion"; } + absl::string_view name() const override { + return "cudnn-conv-padding-legalization"; + } StatusOr Run(HloModule* module) override; @@ -41,4 +43,4 @@ class PadInsertion : public HloModulePass { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PADDING_LEGALIZATION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc similarity index 92% rename from tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc index ef292373018295f5100b91c343df274b626c2fa1..01de110aa9361f5813231767ad01e4aac03cfe0a 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h" #include #include @@ -40,7 +40,8 @@ HloInstruction* CreateCudnnConv(const char* call_target, const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const Window& window, const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count) { + int64 feature_group_count, + const OpMetadata& metadata) { HloComputation* computation = lhs->parent(); // This call returns a tuple of (conv_result, scratch_memory), where @@ -59,6 +60,7 @@ HloInstruction* CreateCudnnConv(const char* call_target, const Shape& shape, custom_call->set_window(window); custom_call->set_convolution_dimension_numbers(dnums); custom_call->set_feature_group_count(feature_group_count); + custom_call->set_metadata(metadata); return custom_call; } @@ -188,9 +190,9 @@ std::tuple MatchBackwardFilter( // the amount of high padding the same as the amount of low padding as long // as it is between min_padding_high and max_padding_high. If it is not in // that range, we pick the one that's closest to dim->padding_low() and let - // PadInsertion canonicalize the resultant backward convolution later. - // Picking the closest one minimizes the cost of the kPad instruction to be - // inserted by PadInsertion. + // CudnnConvPaddingLegalization canonicalize the resultant backward + // convolution later. Picking the closest one minimizes the cost of the kPad + // instruction to be inserted by CudnnConvPaddingLegalization. if (dim->padding_low() >= min_padding_high && dim->padding_low() <= max_padding_high) { dim->set_padding_high(dim->padding_low()); @@ -207,7 +209,8 @@ std::tuple MatchBackwardFilter( "negative padding (" << dim->padding_high() << ") on right/bottom of the weight gradients, which is not " - "supported by PadInsertion (b/32744257). 
Falling back to " + "supported by CudnnConvPaddingLegalization (b/32744257). " + "Falling back to " "unfused convolution for instruction: " << conv->ToString(); return no_match_result; @@ -342,7 +345,8 @@ MatchBackwardInput(HloInstruction* conv) { LOG(ERROR) << "The low padding of the backward convolution would be negative (" << backward_padding_low - << "), which isn't supported by PadInsertion for now (b/32744257)."; + << "), which isn't supported by CudnnConvPaddingLegalization " + "for now (b/32744257)."; return no_match_result; } dim->set_padding_low(backward_padding_low); @@ -371,8 +375,8 @@ MatchBackwardInput(HloInstruction* conv) { dim->set_padding_high(backward_padding_low); } else { // Otherwise, we choose the amount that's closest to backward_padding_low, - // and PadInsertion will later insert kSlice instructions to enforce even - // padding. + // and CudnnConvPaddingLegalization will later insert kSlice + // instructions to enforce even padding. // // For example, consider the backward convolution pattern // @@ -398,9 +402,9 @@ MatchBackwardInput(HloInstruction* conv) { dim->set_padding_high(max_padding_high); } } - // PadInsertion doesn't handle backward input convolution with negative - // padding for now. So fall back to unfused convolution in case of negative - // padding. For example, + // CudnnConvPaddingLegalization doesn't handle backward input + // convolution with negative padding for now. So fall back to unfused + // convolution in case of negative padding. For example, // ABCD = Conv(abc, reverse(xy), padding_high=2) // could be fused to // ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1) @@ -410,8 +414,8 @@ MatchBackwardInput(HloInstruction* conv) { "negative padding (" << dim->padding_high() << ") on right/bottom of the activations, which is not " - "supported by PadInsertion (b/32744257). Falling back to " - "unfused convolution for instruction: " + "supported by CudnnConvPaddingLegalization (b/32744257). 
" + "Falling back to unfused convolution for instruction: " << conv->ToString(); return no_match_result; } @@ -497,22 +501,24 @@ StatusOr RunOnInstruction(HloInstruction* conv) { if (match) { return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1), - window, dnums, conv->feature_group_count()); + window, dnums, conv->feature_group_count(), + conv->metadata()); } std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); if (match) { return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(), conv->mutable_operand(0), rhs, window, dnums, - conv->feature_group_count()); + conv->feature_group_count(), conv->metadata()); } // If all else fails, try a forward convolution. if (CanImplementAsCudnnForwardConv(conv)) { - return CreateCudnnConv( - kCudnnConvForwardCallTarget, conv->shape(), conv->mutable_operand(0), - conv->mutable_operand(1), conv->window(), - conv->convolution_dimension_numbers(), conv->feature_group_count()); + return CreateCudnnConv(kCudnnConvForwardCallTarget, conv->shape(), + conv->mutable_operand(0), conv->mutable_operand(1), + conv->window(), + conv->convolution_dimension_numbers(), + conv->feature_group_count(), conv->metadata()); } return nullptr; @@ -525,6 +531,9 @@ StatusOr RunOnInstruction(HloInstruction* conv) { TF_RETURN_IF_ERROR( custom_call->set_backend_config(GetDefaultBackendConfig())); + VLOG(1) << "Replacing convolution " << conv->ToString() << " with " + << custom_call->ToString(); + // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out // the conv result and replace `conv` with it. 
TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( @@ -552,7 +561,7 @@ StatusOr RunOnComputation(HloComputation* computation) { } } // namespace -StatusOr CudnnConvolutionRewriter::Run(HloModule* module) { +StatusOr CudnnConvRewriter::Run(HloModule* module) { bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations()) { TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h similarity index 74% rename from tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h index 8d7c6fdab510407428a115579a90e8cf85e9fad2..d8ec72c27bab8912d0dc2aeead114eb010b87b78 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_REWRITER_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -24,11 +24,9 @@ namespace gpu { // Rewrites plain convolutions, backwards-filter convolutions, and // backwards-input convolutions into CustomCall HLOs that call into cuDNN. 
-class CudnnConvolutionRewriter : public HloModulePass { +class CudnnConvRewriter : public HloModulePass { public: - absl::string_view name() const override { - return "cudnn-convolution-rewriter"; - } + absl::string_view name() const override { return "cudnn-conv-rewriter"; } StatusOr Run(HloModule* module) override; }; @@ -36,4 +34,4 @@ class CudnnConvolutionRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc similarity index 93% rename from tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc index d237f8930b74d460ad3d4602670a5afb19b496a2..a6980850af370645389f1e10922097f6a16cdee9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -34,9 +34,9 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -class CudnnConvolutionRewriterTest : public HloVerifiedTestBase { +class CudnnConvRewriterTest : public HloVerifiedTestBase { public: - CudnnConvolutionRewriterTest() + CudnnConvRewriterTest() : HloVerifiedTestBase(/*layout_sensitive=*/true, /*allow_mixed_precision=*/false) { for (int i = 0; i < 2; ++i) { @@ -85,7 +85,7 @@ class CudnnConvolutionRewriterTest : public HloVerifiedTestBase { protected: bool RunPass(HloModule* module) { - return CudnnConvolutionRewriter().Run(module).ValueOrDie(); + return CudnnConvRewriter().Run(module).ValueOrDie(); } // A convolution window with stride 1 and zero padding. 
The size fields are @@ -95,7 +95,7 @@ class CudnnConvolutionRewriterTest : public HloVerifiedTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; -TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { +TEST_F(CudnnConvRewriterTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -106,7 +106,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { Window conv_window = default_conv_window_; conv_window.mutable_dimensions(1)->set_size(2); conv_window.mutable_dimensions(1)->set_window_dilation(2); - builder.AddInstruction(HloInstruction::CreateConvolve( + auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape( activations->shape(), gradients->shape(), /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_) @@ -114,16 +114,26 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { activations, gradients, /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); + OpMetadata metadata; + metadata.set_op_name("foo"); + conv->set_metadata(metadata); + auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(RunPass(module)); - EXPECT_THAT(entry_computation->root_instruction(), + ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); + + // Check that metadata was preserved. 
+ const auto& md_after_opt = + entry_computation->root_instruction()->operand(0)->metadata(); + EXPECT_TRUE(protobuf_util::ProtobufEquals(md_after_opt, metadata)) + << md_after_opt.DebugString() << " vs " << metadata.DebugString(); } -TEST_F(CudnnConvolutionRewriterTest, +TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveEquivalentToForwardConvolution) { HloComputation::Builder builder(TestName()); HloInstruction* activations = @@ -152,8 +162,7 @@ TEST_F(CudnnConvolutionRewriterTest, } // Extracted from block35 training. -TEST_F(CudnnConvolutionRewriterTest, - BackwardFilterConvolveWithPaddedActivations) { +TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -183,8 +192,7 @@ TEST_F(CudnnConvolutionRewriterTest, } // Extracted from inception v3 training. -TEST_F(CudnnConvolutionRewriterTest, - BackwardFilterConvolveWithPaddedGradients) { +TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -213,7 +221,7 @@ TEST_F(CudnnConvolutionRewriterTest, op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { +TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -242,7 +250,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveEvenPadding) { auto builder = HloComputation::Builder(TestName()); 
HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -307,7 +315,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { // Convolve([abc], [x], base_dilation=2) // = Convolve([abc], Reverse([x]), base_dilation=2) // = BackwardInputConvolve([abc], [x], stride=2) -TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { +TEST_F(CudnnConvRewriterTest, BackwardInputConvolve1x1Filter) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. HloInstruction* output = @@ -341,7 +349,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { // BackwardInputConvolve([abc], [x], stride=1) is equivalent to // ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input // convolution. -TEST_F(CudnnConvolutionRewriterTest, +TEST_F(CudnnConvRewriterTest, BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. @@ -385,8 +393,7 @@ TEST_F(CudnnConvolutionRewriterTest, // 20x10x10x192 // // Gradients are padded unevenly. -TEST_F(CudnnConvolutionRewriterTest, - BackwardInputConvolveUnevenPaddingOnGradients) { +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -436,7 +443,7 @@ TEST_F(CudnnConvolutionRewriterTest, // Similar to BackwardInputConvolveUnevenPadding, but the low padding of the // gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused. 
-TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -488,9 +495,8 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { // padding_low=2, padding_high=1, base_dilation=2) // // We should fuse BC even though padding on activations is uneven, because -// PadInsertion will canonicalize the fusion HLO. -TEST_F(CudnnConvolutionRewriterTest, - BackwardInputConvolveUnevenPaddingOnActivations) { +// CudnnConvPaddingLegalization will canonicalize the fusion HLO. +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. HloInstruction* output = @@ -543,9 +549,10 @@ TEST_F(CudnnConvolutionRewriterTest, // BC = BackwardInput(FC) does: // [4] = conv([3], reverse([2]), padding_high=2) // -// We currently don't fuse BC because PadInsertion doesn't support negative -// padding on the gradients of backward convolution (b/32744257). -TEST_F(CudnnConvolutionRewriterTest, +// We currently don't fuse BC because CudnnConvPaddingLegalization +// doesn't support negative padding on the gradients of backward convolution +// (b/32744257). +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveNegativePaddingHighOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. @@ -586,7 +593,7 @@ TEST_F(CudnnConvolutionRewriterTest, // Check that we will materialize a reversed version of a constant in order to // pattern-match a backwards input convolution. 
-TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) { +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) { Array4D constant_arr(4, 4, 2, 2); constant_arr.FillIota(0); string constant_str = diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc similarity index 78% rename from tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc index 89dd1bb272663ac1f6eecbaae070d201d38e44c8..0b4fdf71623e1597168c6873a0d2b60176e518ce 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" @@ -110,10 +110,10 @@ class ScratchBufAllocator : public se::ScratchAllocator { }; template -Status RunCudnnConvolutionImpl(CudnnConvParams params, - se::ScratchAllocator* scratch_allocator, - se::Stream* stream, - se::dnn::ProfileResult* profile_result) { +Status RunCudnnConvImpl(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { CudnnConvKind kind = params.kind; const Shape& input_shape = *params.input_shape; const Shape& filter_shape = *params.filter_shape; @@ -312,11 +312,12 @@ StatusOr GetCudnnConvParams( TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, conv->backend_config()); - const auto& target = conv->custom_call_target(); + 
TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(conv)); const auto& lhs_shape = conv->operand(0)->shape(); const auto& rhs_shape = conv->operand(1)->shape(); const auto& conv_result_shape = conv->shape().tuple_shapes(0); + params.kind = kind; params.window = &conv->window(); params.dnums = &conv->convolution_dimension_numbers(); params.feature_group_count = conv->feature_group_count(); @@ -324,77 +325,76 @@ StatusOr GetCudnnConvParams( backend_config.algorithm(), backend_config.tensor_ops_enabled())); params.conv_result_scale = backend_config.conv_result_scale(); - if (target == kCudnnConvForwardCallTarget) { - params.kind = CudnnConvKind::kForward; - params.input_shape = &lhs_shape; - params.filter_shape = &rhs_shape; - params.output_shape = &conv_result_shape; - params.input_buf = operand_buffers[0]; - params.filter_buf = operand_buffers[1]; - params.output_buf = result_buffer; - } else if (target == kCudnnConvBackwardInputCallTarget) { - params.kind = CudnnConvKind::kBackwardInput; - params.input_shape = &conv_result_shape; - params.filter_shape = &rhs_shape; - params.output_shape = &lhs_shape; - params.input_buf = result_buffer; - params.filter_buf = operand_buffers[1]; - params.output_buf = operand_buffers[0]; - } else if (target == kCudnnConvBackwardFilterCallTarget) { - params.kind = CudnnConvKind::kBackwardFilter; - params.input_shape = &lhs_shape; - params.filter_shape = &conv_result_shape; - params.output_shape = &rhs_shape; - params.input_buf = operand_buffers[0]; - params.filter_buf = result_buffer; - params.output_buf = operand_buffers[1]; - } else if (target == kCudnnConvBiasActivationForwardCallTarget) { - params.kind = CudnnConvKind::kForwardActivation; - params.input_shape = &lhs_shape; - params.filter_shape = &rhs_shape; - params.output_shape = &conv_result_shape; - params.fusion.emplace(); - auto& fusion = *params.fusion; - if (backend_config.activation_mode() < - static_cast(se::dnn::ActivationMode::kNumActivationModes)) { - fusion.mode 
= static_cast( - backend_config.activation_mode()); - } else { - return InternalError("Bad activation mode: %s", - backend_config.ShortDebugString()); - } - fusion.side_input_scale = backend_config.side_input_scale(); - params.input_buf = operand_buffers[0]; - params.filter_buf = operand_buffers[1]; - params.output_buf = result_buffer; - params.fusion->bias_buf = operand_buffers[2]; - if (operand_buffers.size() >= 4) { - params.fusion->side_input_buf = operand_buffers[3]; + switch (kind) { + case CudnnConvKind::kForward: + params.input_shape = &lhs_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + break; + case CudnnConvKind::kBackwardInput: + params.input_shape = &conv_result_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &lhs_shape; + params.input_buf = result_buffer; + params.filter_buf = operand_buffers[1]; + params.output_buf = operand_buffers[0]; + break; + case CudnnConvKind::kBackwardFilter: + params.input_shape = &lhs_shape; + params.filter_shape = &conv_result_shape; + params.output_shape = &rhs_shape; + params.input_buf = operand_buffers[0]; + params.filter_buf = result_buffer; + params.output_buf = operand_buffers[1]; + break; + case CudnnConvKind::kForwardActivation: { + params.kind = CudnnConvKind::kForwardActivation; + params.input_shape = &lhs_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.fusion.emplace(); + auto& fusion = *params.fusion; + if (backend_config.activation_mode() < + static_cast(se::dnn::ActivationMode::kNumActivationModes)) { + fusion.mode = static_cast( + backend_config.activation_mode()); + } else { + return InternalError("Bad activation mode: %s", + backend_config.ShortDebugString()); + } + fusion.side_input_scale = backend_config.side_input_scale(); + params.input_buf = operand_buffers[0]; + 
params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + params.fusion->bias_buf = operand_buffers[2]; + if (operand_buffers.size() >= 4) { + params.fusion->side_input_buf = operand_buffers[3]; + } } - } else { - return InternalError("Unexpected custom call target: %s", target); } return params; } } // anonymous namespace -Status RunCudnnConvolution(const HloCustomCallInstruction* conv, - absl::Span operand_buffers, - se::DeviceMemoryBase result_buffer, - se::DeviceMemoryBase scratch_buf, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { +Status RunCudnnConv(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution(conv, operand_buffers, result_buffer, - &scratch_allocator, stream, profile_result); + return RunCudnnConv(conv, operand_buffers, result_buffer, &scratch_allocator, + stream, profile_result); } -Status RunCudnnConvolution(const HloCustomCallInstruction* conv, - absl::Span operand_buffers, - se::DeviceMemoryBase result_buffer, - se::ScratchAllocator* scratch_allocator, - se::Stream* stream, - se::dnn::ProfileResult* profile_result) { +Status RunCudnnConv(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::ScratchAllocator* scratch_allocator, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { TF_ASSIGN_OR_RETURN(CudnnConvParams params, GetCudnnConvParams(conv, operand_buffers, result_buffer)); @@ -402,14 +402,14 @@ Status RunCudnnConvolution(const HloCustomCallInstruction* conv, conv->shape().tuple_shapes(0).element_type(); switch (output_primitive_type) { case F16: - return RunCudnnConvolutionImpl(params, scratch_allocator, - stream, profile_result); + return RunCudnnConvImpl(params, scratch_allocator, stream, + 
profile_result); case F32: - return RunCudnnConvolutionImpl(params, scratch_allocator, stream, - profile_result); + return RunCudnnConvImpl(params, scratch_allocator, stream, + profile_result); case F64: - return RunCudnnConvolutionImpl(params, scratch_allocator, stream, - profile_result); + return RunCudnnConvImpl(params, scratch_allocator, stream, + profile_result); default: LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h similarity index 67% rename from tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h index 61aec1ceccec0f253f9ddaa688d64cacea800cf3..edbc75a94a1238540390b93f0fa5217852c7781f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_RUNNER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_RUNNER_H_ #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -42,20 +42,19 @@ namespace gpu { // allocator and take note of how much memory is used. The next time you call // the same conv, you can provide an explicitly preallocated scratch buffer of // that size, if you like. 
-Status RunCudnnConvolution(const HloCustomCallInstruction* conv, - absl::Span operand_buffers, - se::DeviceMemoryBase result_buffer, - se::DeviceMemoryBase scratch_buf, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); - -Status RunCudnnConvolution(const HloCustomCallInstruction* conv, - absl::Span operand_buffers, - se::DeviceMemoryBase result_buffer, - se::ScratchAllocator* scratch_allocator, - se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); +Status RunCudnnConv(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); + +Status RunCudnnConv(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::ScratchAllocator* scratch_allocator, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc similarity index 97% rename from tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc index 3761c19cfcab10e0c6faa17c2d1d535d706ff6c5..cde65ad5745a3c102d029907e0690dc8c34620fd 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" @@ -226,6 +226,7 @@ StatusOr> TryRewriteToCudnnForwardRelu( new_conv->set_window(conv->window()); new_conv->set_convolution_dimension_numbers( conv->convolution_dimension_numbers()); + new_conv->set_metadata(conv->metadata()); TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, conv->backend_config()); config.set_activation_mode( @@ -234,14 +235,15 @@ StatusOr> TryRewriteToCudnnForwardRelu( config.set_side_input_scale(alpha_side_input); TF_RETURN_IF_ERROR(new_conv->set_backend_config(config)); - VLOG(1) << "Rewriting " << conv->name() << " to " << new_conv->name(); + VLOG(1) << "Replacing convolution " << conv->ToString() << " with " + << new_conv->ToString(); return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0), new_conv, 0); } } // namespace -StatusOr CudnnFusedConvolutionRewriter::Run(HloModule* module) { +StatusOr CudnnFusedConvRewriter::Run(HloModule* module) { bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations()) { std::vector matches; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h similarity index 77% rename from tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h rename to tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h index bd12aadded9dd9e19bc695ddc11e5529931a306a..613ed8dbdc33dfc3684deb5fd3ee8f5b9ea5fc50 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific 
language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -22,7 +22,7 @@ limitations under the License. namespace xla { namespace gpu { -class CudnnFusedConvolutionRewriter : public HloModulePass { +class CudnnFusedConvRewriter : public HloModulePass { public: absl::string_view name() const override { return "cudnn-fused-convolution-rewriter"; @@ -34,4 +34,4 @@ class CudnnFusedConvolutionRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc similarity index 79% rename from tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 5632cac1862e21825888d94ab1eee5e1c9fd6800..b7dd07a50c637d514439bb7a8ec799e4cabfee55 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" @@ -22,7 +23,10 @@ namespace xla { namespace gpu { namespace { -class CudnnFusedConvolutionRewriterTest : public HloTestBase { +using ::testing::HasSubstr; +using ::testing::Not; + +class CudnnFusedConvRewriterTest : public HloTestBase { protected: string GetOptimizedHlo(absl::string_view hlo_string) { return backend() @@ -39,13 +43,11 @@ class CudnnFusedConvolutionRewriterTest : public HloTestBase { for (absl::string_view type : {"f16", "f32", "f64"}) { const string hlo_with_new_type = absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); - const string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type); - EXPECT_EQ(absl::string_view::npos, - optimized_hlo_string.find("__cudnn$convForward")) - << optimized_hlo_string; - EXPECT_NE(absl::string_view::npos, - optimized_hlo_string.find("__cudnn$convBiasActivationForward")) - << optimized_hlo_string; + string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type); + EXPECT_THAT(optimized_hlo_string, + Not(HasSubstr(kCudnnConvForwardCallTarget))); + EXPECT_THAT(optimized_hlo_string, + HasSubstr(kCudnnConvBiasActivationForwardCallTarget)); EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01})) << optimized_hlo_string; } @@ -55,18 +57,15 @@ class CudnnFusedConvolutionRewriterTest : public HloTestBase { for (absl::string_view type : {"f16", "f32", "f64"}) { const string hlo_with_new_type = absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); - string optimized_hlo = GetOptimizedHlo(hlo_with_new_type); - EXPECT_NE(absl::string_view::npos, - optimized_hlo.find("__cudnn$convForward")) - << optimized_hlo; - EXPECT_EQ(absl::string_view::npos, - 
optimized_hlo.find("__cudnn$convBiasActivationForward")) - << optimized_hlo; + string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type); + EXPECT_THAT(optimized_hlo_string, HasSubstr(kCudnnConvForwardCallTarget)); + EXPECT_THAT(optimized_hlo_string, + Not(HasSubstr(kCudnnConvBiasActivationForwardCallTarget))); } } }; -TEST_F(CudnnFusedConvolutionRewriterTest, TestConvOnly) { +TEST_F(CudnnFusedConvRewriterTest, TestConvOnly) { // max(0, conv(x, w)); TestMatchWithAllTypes(R"( HloModule Test @@ -83,7 +82,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestConvOnly) { })"); } -TEST_F(CudnnFusedConvolutionRewriterTest, TestBias) { +TEST_F(CudnnFusedConvRewriterTest, TestBias) { // max(0, conv(x, w) + bias); TestMatchWithAllTypes(R"( HloModule Test @@ -103,7 +102,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestBias) { })"); } -TEST_F(CudnnFusedConvolutionRewriterTest, TestSideInputOnly) { +TEST_F(CudnnFusedConvRewriterTest, TestSideInputOnly) { // max(0, conv(x, w) + side_input); TestMatchWithAllTypes(R"( HloModule Test @@ -122,7 +121,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestSideInputOnly) { })"); } -TEST_F(CudnnFusedConvolutionRewriterTest, TestBiasAndSideInput) { +TEST_F(CudnnFusedConvRewriterTest, TestBiasAndSideInput) { // max(0, conv(x, w) + side_input + bias); TestMatchWithAllTypes(R"( HloModule Test @@ -144,7 +143,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestBiasAndSideInput) { })"); } -TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConv) { +TEST_F(CudnnFusedConvRewriterTest, TestScaledConv) { // max(0, 0.999994934 * conv(x, w)); TestMatchWithAllTypes(R"( HloModule Test @@ -164,7 +163,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConv) { })"); } -TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndSideInput) { +TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndSideInput) { // max(0, conv(x, w) + 0.899994934 * side_input); TestMatchWithAllTypes(R"( HloModule Test @@ -186,7 +185,7 @@ 
TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndSideInput) { })"); } -TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndScaledSideInput) { +TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInput) { // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input); TestMatchWithAllTypes(R"( HloModule Test @@ -211,8 +210,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndScaledSideInput) { })"); } -TEST_F(CudnnFusedConvolutionRewriterTest, - TestScaledConvAndScaledSideInputWithBias) { +TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInputWithBias) { // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias); TestMatchWithAllTypes(R"( HloModule Test @@ -240,7 +238,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, })"); } -TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchMaxZeroOnly) { +TEST_F(CudnnFusedConvRewriterTest, TestMatchMaxZeroOnly) { // max(0.1, conv(x, w)) shouldn't match. TestNotMatchWithAllTypes(R"( HloModule Test @@ -257,7 +255,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchMaxZeroOnly) { })"); } -TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchBroadcastedBiasOnly) { +TEST_F(CudnnFusedConvRewriterTest, TestMatchBroadcastedBiasOnly) { // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match. 
TestNotMatchWithAllTypes(R"( HloModule Test @@ -278,6 +276,35 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchBroadcastedBiasOnly) { })"); } +TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) { + const char* kHloString = R"( + HloModule Test + + ENTRY Test { + zero = f32[] constant(0) + zeros = f32[1,32,9,9] broadcast(zero), dimensions={} + + input = f32[1,17,9,9] parameter(0) + filter = f32[3,3,17,32] parameter(1) + + conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1, metadata={op_type="foo"} + ROOT relu = f32[1,32,9,9] maximum(zeros, conv) + })"; + + const string optimized_hlo_string = + backend() + .compiler() + ->RunHloPasses(ParseHloString(kHloString, GetModuleConfigForTest()) + .ConsumeValueOrDie(), + backend().default_stream_executor(), + backend().memory_allocator()) + .ConsumeValueOrDie() + ->ToString(); + EXPECT_THAT( + optimized_hlo_string, + ::testing::ContainsRegex(R"(custom-call.*metadata=\{op_type="foo"\})")); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index c1aaa4bf04ddc31edf723c056805ae5aad994e55..6dcdaf1cfe06e446deed847aaf29088a7ed10e13 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -358,13 +358,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); const Window& window = hlo->window(); - // TODO(b/31410564): Implement dilation for reduce-window. - if (window_util::HasDilation(window)) { - return Unimplemented( - "Dilation for reduce-window not implemented on GPU. 
" - "See b/31410564."); - } - PrimitiveType operand_element_type = operand->shape().element_type(); llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), @@ -397,9 +390,24 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( for (size_t i = 0; i < index.size(); ++i) { llvm::Value* stridden_index = NSWMul( index[i], index_typed_const(window.dimensions(i).stride())); + input_index[i] = NSWSub( + NSWAdd(stridden_index, + NSWMul(window_index[i], + index_typed_const( + window.dimensions(i).window_dilation()))), + index_typed_const(window.dimensions(i).padding_low())); + + // We need to verify that we are not in the dilated base area. + llvm::Value* dilation_condition = ICmpEQ( + SRem(input_index[i], + index_typed_const(window.dimensions(i).base_dilation())), + index_typed_const(0)); + in_bounds = And(in_bounds, dilation_condition); + + // Apply base dilation to the index. input_index[i] = - NSWSub(NSWAdd(stridden_index, window_index[i]), - index_typed_const(window.dimensions(i).padding_low())); + SDiv(input_index[i], + index_typed_const(window.dimensions(i).base_dilation())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. This diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 79c74e7e8bf3a1aa59243b81942d29180bb46e74..e2ab00ce41c9e23e91449f249620d61d0f7736ae 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 31a9f9b1beb81da81a06f6dc8e7c13c105514092..57426327822d95a42f407ed7488f35acfd3623d2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" @@ -197,7 +198,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { } module_spec.AddCudaPtxInMemory(ptx().c_str()); - tensorflow::gtl::FlatMap globals; + absl::flat_hash_map globals; se::ModuleHandle module_handle; executor->LoadModule(module_spec, &module_handle); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 38b0f8f15bd28cf2659e4a53b6634e981545716b..0e276282e40fba0ae4881a51dad0c7c9e8d1c081 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -101,7 +101,7 @@ class GpuExecutable : public Executable { const PointsToSet& GetRootPointsToSet() const; using BufferAllocToDeviceMemoryMap = - tensorflow::gtl::FlatMap; + absl::flat_hash_map; // Loads the PTX or CUBIN for this executable into `executor` and resolves the // globals corresponding to constant buffers. Returns a map mapping buffer diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc index 27a4d0b601f3807fe6b94dd6171a44f292921ede..7d01eeb02567d710e9de089c7f29ffcc5f959f9a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -57,10 +57,13 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { HloInstruction::CreateParameter(1, sparse_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( sparse_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); + // Since verifier is reporting sparse layouts as errors, we should + // use a regular HloModule instead of VerifiedHloModule to avoid + // verifier errors being triggered in the destructor. 
+ auto module = HloTestBase::CreateNewModule(); module->AddEntryComputation(builder.Build()); - Status status = checker().Run(module).status(); + Status status = checker().Run(module.get()).status(); ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); EXPECT_THAT(status.error_message(), HasSubstr("GPU backend does not support")); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 74352f26aa9c3a2ca597da21735438df92f863ab..1c0a23fa3eb38961d420aff05e412c3b4d8524e7 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -125,14 +124,8 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( DataLayout input; FilterLayout filter; DataLayout output; - if (ConvUseLayoutHeuristic(instr->GetModule()->config())) { - std::tie(input, filter, output) = - HeuristicLayoutAssignment(instr, stream_executor_); - } else { - input = DataLayout::kBatchDepthYX; - filter = FilterLayout::kOutputInputYX; - output = DataLayout::kBatchDepthYX; - } + std::tie(input, filter, output) = + HeuristicLayoutAssignment(instr, stream_executor_); TF_ASSIGN_OR_RETURN( std::tie(*input_shape->mutable_layout(), @@ -215,21 +208,37 @@ Status GpuLayoutAssignment::AddBackendConstraints( constraints->SetOperandLayout(op1_shape, instruction, 1)); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, instruction)); + } else if (instruction->opcode() == HloOpcode::kSort && + ShapeUtil::Rank(instruction->operand(0)->shape()) > 1) { + // 
Make sure that all the operands and the output(s) have the same layout. + Shape keys_shape = instruction->operand(0)->shape(); + Layout keys_layout = + LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(keys_shape)); + for (int64 i = 0; i < instruction->operand_count(); ++i) { + Shape shape = instruction->operand(i)->shape(); + *shape.mutable_layout() = keys_layout; + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(shape, instruction, i)); + const LogicalBuffer* output_buffer; + if (ShapeUtil::IsArray(instruction->shape())) { + TF_ASSIGN_OR_RETURN( + output_buffer, + constraints->points_to_analysis().GetBufferDefinedAt(instruction, + {})); + } else { + TF_ASSIGN_OR_RETURN( + output_buffer, + constraints->points_to_analysis().GetBufferDefinedAt(instruction, + {i})); + } + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(keys_layout, *output_buffer)); + } } } return Status::OK(); } -bool GpuLayoutAssignment::CustomCallRequiresMajorFirstLayout( - const HloInstruction* instruction) { - // - Inputs to cudnn batchnorm custom calls don't need the major-first layout - // (i.e. {n, n-1, ...0}) -- we can handle any layout. - // - Inputs to cudnn convolution require custom layouts handled in - // AddBackendConstraints. 
- return !IsCustomCallToDnnBatchNorm(*instruction) && - !IsCustomCallToDnnConvolution(*instruction); -} - Status GpuLayoutAssignment::PropagateOperandConstraint( const OperandLayoutConstraint& layout_constraint, LayoutConstraints* constraints) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index e2b96a81d4de1337de2978a9d3c6c38c6e5fd0cd..6a48e55fd2e784f80a50f4565107db177fb43bfc 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -30,8 +30,11 @@ namespace gpu { class GpuLayoutAssignment : public LayoutAssignment { public: explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout, + std::function + instruction_can_change_layout_func, se::StreamExecutor* stream_executor) - : LayoutAssignment(entry_computation_layout), + : LayoutAssignment(entry_computation_layout, + std::move(instruction_can_change_layout_func)), stream_executor_(stream_executor) {} ~GpuLayoutAssignment() override {} @@ -43,8 +46,6 @@ class GpuLayoutAssignment : public LayoutAssignment { Status PropagateBufferConstraint( const BufferLayoutConstraint& buffer_constraint, LayoutConstraints* constraints) override; - bool CustomCallRequiresMajorFirstLayout( - const HloInstruction* instruction) override; private: Status AddBackendConstraintsToDnnConvCustomCall( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index fbc8ddf599570b90e93eb463a1fd6c275b73711c..4822b820f4e229336e2b26cfbd0097c8c31a50c8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -75,7 +75,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) { ShapeLayout(result_shape_with_layout); GpuLayoutAssignment layout_assignment( - &computation_layout, 
backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); for (const HloInstruction* operand : add->operands()) { @@ -163,7 +164,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { } GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -233,7 +235,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { } GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -314,7 +317,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { } GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first and fourth operands to the batchnorm call should have the @@ -347,9 +351,11 @@ TEST_F(LayoutAssignmentTest, DotLayout) { ParseHloString(hlo_text)); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + module->entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); + GpuLayoutAssignment layout_assignment( + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + 
backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); Shape expected_shape = @@ -359,6 +365,34 @@ TEST_F(LayoutAssignmentTest, DotLayout) { op::ShapeWithLayout(expected_shape))); } +TEST_F(LayoutAssignmentTest, SortLayout) { + const char* hlo_text = R"( + HloModule SortLayout + ENTRY sort { + keys = f32[3,2]{0,1} constant(f32[3,2]{0,1}{{0,1},{0,1},{0,1}}) + values = f32[2,3]{1,0} parameter(0) + transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0} + ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose), + dimensions={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); + GpuLayoutAssignment layout_assignment( + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + Shape expected_shape = ShapeUtil::MakeShapeWithLayout(F32, {3, 2}, {1, 0}); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sort(op::ShapeWithLayout(expected_shape), + op::ShapeWithLayout(expected_shape))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 4d5d8e99f88149aabfd0a4aeafc7e6724d29418d..1d66787d8927ad818cbc66d19429c1816fc51748 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -46,6 +47,7 @@ bool IsFusible(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || hlo.opcode() == HloOpcode::kReshape || + hlo.opcode() == HloOpcode::kScatter || hlo.opcode() == HloOpcode::kSlice || hlo.opcode() == HloOpcode::kTranspose; } @@ -125,8 +127,8 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { } // Compute the precise number of operands to the new fusion. - tensorflow::gtl::FlatSet operands( - a->operands().begin(), a->operands().end()); + absl::flat_hash_set operands(a->operands().begin(), + a->operands().end()); operands.insert(b->operands().begin(), b->operands().end()); // If there's an edge between `a` and `b`, don't count it: We're fusing that // producer -> consumer relationship. @@ -222,6 +224,11 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } + // Scatter is only supported at the root of a kInput fusion. + if (producer->opcode() == HloOpcode::kScatter) { + return false; + } + // Do not fuse into reduce input fusions if the resulting kernel would suffer // from poor data locality (due to unfriendly input layouts). 
if (IsInputFusibleReduction(*consumer) && @@ -284,7 +291,8 @@ bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) { - if (IsReductionToVector(*consumer)) { + if (IsReductionToVector(*consumer) || + consumer->opcode() == HloOpcode::kScatter) { return HloInstruction::FusionKind::kInput; } if (producer->opcode() == HloOpcode::kDot || diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 96bfe0c12eb9cd6ef25804d6b34767471616f7e4..fd9b7cee80bdad9a8ed625872ae68bede10200b3 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -709,5 +709,44 @@ TEST_F(InstructionFusionTest, AvoidsLargeFusion) { } } +TEST_F(InstructionFusionTest, FuseIntoScatter) { + auto module = ParseHloString(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY FuseIntoScatter { + p0 = s32[3,3] parameter(0) + operand = s32[3,3] add(p0, p0) + p1 = s32[2] parameter(1) + indices = s32[2] add(p1, p1) + p2 = s32[2,3] parameter(2) + updates = s32[2,3] add(p2, p2) + scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + ROOT add = s32[3,3] add(scatter, scatter) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Fusion(), op::Fusion())); + EXPECT_EQ(root->operand(0)->fusion_kind(), + HloInstruction::FusionKind::kInput); + EXPECT_THAT(root->operand(0)->fused_expression_root(), + 
op::Scatter(op::Add(), op::Add(), op::Add())); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index a64a616ab1329422d0197f4a7f99ec557a95f8ed..f373d4a8393a047aba599b0fae954e98a740161e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -108,9 +108,9 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); // memory used by cudnn. Callers shouldn't inspect scratch_memory, as its value // is not well-defined. // -// CudnnConvolutionRewriter lowers kConvolution HLOs to these custom calls. +// CudnnConvRewriter lowers kConvolution HLOs to these custom calls. // When it does so, it chooses algorithm -1 and 0 bytes of scratch space. Later -// on in the pipeline, CudnnConvolutionAlgorithmChooser chooses an explicit +// on in the pipeline, CudnnConvAlgorithmChooser chooses an explicit // algorithm for each conv and sets the amount of scratch space needed. // // (Representing the scratch memory as an output may seem strange at first, but diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index b7c37bcf3ca910f10d18339dfe7f1d29f2a55c9e..a3821e077ecf6b1dce1e2c8785fe3a59516db2be 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -179,6 +179,21 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( bool is_atomic_integral = element_type == S32 || element_type == U32 || element_type == S64 || element_type == U64; llvm::Value* source = Load(source_address, "source"); + + // kCopy of RHS -> atomic store. 
+ if (root_opcode == HloOpcode::kCopy && + (element_type == F32 || is_atomic_integral) && + computation.root_instruction()->operand(0)->opcode() == + HloOpcode::kParameter && + computation.root_instruction()->operand(0)->parameter_number() == 1) { + llvm::StoreInst* store = Store(source, output_address); + store->setAtomic(llvm::AtomicOrdering::Unordered); + // Derive a minimum alignment from the type. The optimizer can increase it + // later. + store->setAlignment(ShapeUtil::ByteSizeOfPrimitiveType(element_type)); + return true; + } + if (root_opcode == HloOpcode::kAdd) { // NVPTX supports atomicAdd on F32 and integer types. if (element_type == F32) { @@ -480,18 +495,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { TF_RET_CHECK(!ShapeUtil::IsScalar(lhs_shape) && !ShapeUtil::IsScalar(rhs_shape)); - // Reduce along the last dimension of the LHS and the second-to-last dimension - // of the RHS. Vectors are a special case where the reduction dimension is 0 - // for both LHS and RHS. This results in a vector dot product producing a - // scalar. - const int64 lhs_reduction_dimension = - ShapeUtil::GetDimensionNumber(lhs_shape, -1); - const int64 rhs_reduction_dimension = - ShapeUtil::Rank(rhs_shape) >= 2 + dnums.lhs_batch_dimensions_size() - ? ShapeUtil::GetDimensionNumber(rhs_shape, -2) - : dnums.lhs_batch_dimensions_size(); - - // Check that the batch dims don't cover the last two dims. + const int64 lhs_reduction_dimension = dnums.lhs_contracting_dimensions(0); + const int64 rhs_reduction_dimension = dnums.rhs_contracting_dimensions(0); + + // Check that the batch dims don't cover the reduction dimensions. for (int64 batch_dim : dnums.lhs_batch_dimensions()) { CHECK_NE(lhs_reduction_dimension, batch_dim); CHECK_NE(rhs_reduction_dimension, batch_dim); @@ -499,7 +506,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Verify the reduction dimension in the two operands are the same size. 
TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == - rhs_shape.dimensions(rhs_reduction_dimension)); + rhs_shape.dimensions(rhs_reduction_dimension)) + << "lhs_shape.dimensions(" << lhs_reduction_dimension + << ") = " << lhs_shape.dimensions(lhs_reduction_dimension) + << ", and rhs_shape.dimensions(" << rhs_reduction_dimension + << ") = " << rhs_shape.dimensions(rhs_reduction_dimension); // Create loop nests which loop through the LHS operand dimensions and the RHS // operand dimensions. The reduction dimension of the LHS and RHS are handled diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index c792dd2ddb0faeba076548ba104aa291e0814140..eb8aaaea4f91f552c2f21f104b83924fd604ebfa 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -34,6 +34,7 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" @@ -43,7 +44,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" @@ -493,13 +494,68 @@ Status IrEmitterUnnested::HandleFft(HloInstruction* fft) { Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); - // HandleFusion specializes reduction from a multi-dimensional array to a 1D - // array. The specialized version requires a initializer thunk that - // initializes the output array to the initial value of the reduce. if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) { switch (root->opcode()) { + case HloOpcode::kScatter: { + std::vector> thunks; + // The initialization from 'operand' is using different loop bounds, so + // emit it in a separate kernel. Treat it like a loop fusion, writing to + // the output buffer. 
+ { + int unroll_factor = ComputeMaxUnrollFactor(fusion); + thunks.push_back(BuildKernelThunk( + fusion, /*implements_whole_instruction=*/false, unroll_factor)); + + std::vector operand_parameter_arrays; + for (HloInstruction* operand : fusion->operands()) { + operand_parameter_arrays.push_back(GetIrArray(*operand, *fusion)); + } + GpuElementalIrEmitter operand_elemental_emitter( + hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, + GetNestedComputer()); + FusedIrEmitter operand_fused_emitter(operand_parameter_arrays, + &operand_elemental_emitter); + TF_RETURN_IF_ERROR( + root->mutable_operand(0)->Accept(&operand_fused_emitter)); + + TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk( + *fusion, operand_fused_emitter.GetGenerator(root->operand(0)), + static_cast(thunks.back().get()))); + } + + // Now build the actual scatter, reading and writing to the freshly + // filled output buffer. + { + thunks.push_back( + BuildKernelThunk(fusion, + /*implements_whole_instruction=*/false)); + // Spin up a new fused emitter for the scatter kernel and emit it. 
+ std::vector scatter_parameter_arrays; + for (HloInstruction* operand : fusion->operands()) { + scatter_parameter_arrays.push_back(GetIrArray(*operand, *fusion)); + } + GpuElementalIrEmitter scatter_elemental_emitter( + hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, + GetNestedComputer()); + FusedIrEmitter scatter_fused_emitter(scatter_parameter_arrays, + &scatter_elemental_emitter); + TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter)); + TF_RETURN_IF_ERROR(EmitScatter( + thunks.back().get(), root, + /*scatter_indices_gen=*/ + scatter_fused_emitter.GetGenerator(root->operand(1)), + /*updates_gen=*/ + scatter_fused_emitter.GetGenerator(root->operand(2)))); + } + thunk_sequence_->emplace_back( + absl::make_unique(std::move(thunks), fusion)); + return Status::OK(); + } case HloOpcode::kTuple: case HloOpcode::kReduce: { + // HandleFusion specializes reduction from a multi-dimensional array to + // a 1D array. The specialized version requires an initializer thunk + // that initializes the output array to the initial value of the reduce. if (root->opcode() == HloOpcode::kReduce && ShapeUtil::IsTuple(root->shape())) { // TODO(b/112040122): Support variadic reduce. @@ -1672,6 +1728,14 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { } Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { + // For the root node of the entry computation we can elide writing the tuple + // buffer. We can always figure out the contents of the tuples from buffer + // assignment because we insert copies to ensure non-ambiguous output buffers. + // GpuExecutable never reads the tuple buffer. 
+ if (tuple == + tuple->parent()->parent()->entry_computation()->root_instruction()) { + return Status::OK(); + } bool all_tuple_elements_have_buffer = absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { return ir_emitter_context_->buffer_assignment() @@ -1958,6 +2022,178 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) { return Status::OK(); } +Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { + const HloInstruction* operand = scatter->operand(0); + const HloInstruction* scatter_indices = scatter->operand(1); + const HloInstruction* updates = scatter->operand(2); + + std::vector> thunks; + + // Copy the operand into the output if it's not the same buffer already. + auto operand_buffer = GetAllocationSlice(*operand); + auto destination_buffer = GetAllocationSlice(*scatter); + if (operand_buffer != destination_buffer) { + thunks.push_back(absl::make_unique( + /*source_address=*/operand_buffer, + /*destination_buffer=*/destination_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()), scatter)); + } + + thunks.push_back( + BuildKernelThunk(scatter, + /*implements_whole_instruction=*/thunks.empty())); + + TF_RETURN_IF_ERROR( + EmitScatter(thunks.back().get(), scatter, + /*scatter_indices_gen=*/ + [=](const IrArray::Index& index) { + return GetIrArray(*scatter_indices, *scatter) + .EmitReadArrayElement(index, &b_, "scatter_index"); + }, + /*updates_gen=*/ + [=](const IrArray::Index& index) { + return GetIrArray(*updates, *scatter) + .EmitReadArrayElement(index, &b_, "update"); + })); + + // Elide the sequential thunk if there's no copy. 
+ if (thunks.size() == 1) { + thunk_sequence_->push_back(std::move(thunks[0])); + } else { + thunk_sequence_->emplace_back( + absl::make_unique(std::move(thunks), scatter)); + } + return Status::OK(); +} + +Status IrEmitterUnnested::EmitScatter( + Thunk* thunk, HloInstruction* scatter, + const llvm_ir::ElementGenerator& scatter_indices_gen, + const llvm_ir::ElementGenerator& updates_gen) { + const HloInstruction* operand = scatter->operand(0); + const HloInstruction* scatter_indices = scatter->operand(1); + const HloInstruction* updates = scatter->operand(2); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + CHECK(ShapeUtil::Equal(scatter->shape(), operand->shape())); + + auto loop_body_emitter = [&](const IrArray::Index& index) -> Status { + std::vector raw_window_multidim; + std::vector input_scatter_multidim; + std::vector raw_window_bounds; + + // Partition the index into window indices and scatter indices. + for (int64 i = 0, e = index.size(); i != e; ++i) { + // For window indices also remember the window size, this comes in handy + // later. + if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) { + raw_window_multidim.push_back(index[i]); + raw_window_bounds.push_back(updates->shape().dimensions(i)); + } else { + input_scatter_multidim.push_back(index[i]); + } + } + DCHECK_EQ(raw_window_multidim.size(), + dim_numbers.update_window_dims_size()); + + // Apply inserted_window_dims to the window dimensions. + int64 raw_window_multidim_idx = 0; + std::vector input_window_multidim; + std::vector input_window_bounds; + for (int64 i = 0, e = ShapeUtil::Rank(operand->shape()); i != e; ++i) { + if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { + input_window_bounds.push_back(1); // Trivial dimension. 
+ input_window_multidim.push_back(index.GetConstantWithIndexType(0)); + } else { + input_window_bounds.push_back( + raw_window_bounds[raw_window_multidim_idx]); + input_window_multidim.push_back( + raw_window_multidim[raw_window_multidim_idx]); + ++raw_window_multidim_idx; + } + } + DCHECK_EQ(input_window_multidim.size(), ShapeUtil::Rank(operand->shape())); + + // Insert a 1 dimension at the end if index_vector_dim requests one. + Shape scatter_indices_shape = scatter_indices->shape(); + if (dim_numbers.index_vector_dim() == + ShapeUtil::Rank(scatter_indices_shape)) { + scatter_indices_shape.add_dimensions(1); + scatter_indices_shape.mutable_layout()->add_minor_to_major( + dim_numbers.index_vector_dim()); + } + + // Now load the indices corresponding to the current window from + // scatter_indices. + llvm_ir::IrArray::Index raw_scatter_index_index(input_scatter_multidim, + index.GetType()); + raw_scatter_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr); + llvm::Value* is_in_bounds = b_.getTrue(); + for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size(); + i != e; ++i) { + // Our index is stored along index_vector_dim, insert that into the lookup + // index into scatter_indices. + raw_scatter_index_index[dim_numbers.index_vector_dim()] = + raw_scatter_index_index.GetConstantWithIndexType(i); + + int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i); + TF_ASSIGN_OR_RETURN( + llvm::Value* const loaded_scatter_index, + scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape( + scatter_indices_shape, scatter_indices->shape(), &b_))); + // And add the index to our window index. This yields the output index. + llvm::Value* casted_scatter_index = + IntCast(loaded_scatter_index, index.GetType(), + /*isSigned=*/true); + llvm::Value* dim_offset = + Add(input_window_multidim[operand_dim], casted_scatter_index); + input_window_multidim[operand_dim] = dim_offset; + + // Also do the bounds check now. 
+ int64 max_index = operand->shape().dimensions(operand_dim) - + input_window_bounds[operand_dim] + 1; + // is_in_bounds = index >= 0 && index < dim_size-window_size+1 + // --> index u< dim_size-window_size+1 + is_in_bounds = + And(is_in_bounds, ICmpULT(casted_scatter_index, + index.GetConstantWithIndexType(max_index))); + } + + llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse( + is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false); + llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_); + // All done, now just read from the calculated input from the window, and do + // an atomic store to the calculated location in the output. + llvm_ir::IrArray::Index input_window_index(input_window_multidim, + index.GetType()); + HloInstruction* output_hlo = + scatter->IsFused() ? scatter->parent()->FusionInstruction() : scatter; + llvm::Value* output_address = + GetIrArray(*output_hlo, *output_hlo) + .EmitArrayElementAddress(input_window_index, &b_); + llvm::Value* input_address = Alloca(llvm_ir::PrimitiveTypeToIrType( + updates->shape().element_type(), module_)); + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index)); + Store(input_ir_value, input_address); + return EmitAtomicOperationForNestedComputation( + *scatter->to_apply(), output_address, input_address); + }; + + // Launch a kernel that reads every element in the updates tensor. We could + // also do one kernel per window instead if bounds checks turn out to be a + // bottleneck. 
+ LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + updates->shape(), ir_emitter_context_->device_description()); + UpdateLaunchDimensions(launch_dimensions, thunk, + ir_emitter_context_->llvm_module()); + + return ParallelLoopEmitter(loop_body_emitter, updates->shape(), + launch_dimensions, &b_) + .EmitLoop(IrName(scatter), + GetIndexTypeForKernel(scatter, launch_dimensions.launch_bound(), + &b_)); +} + Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { thunk_sequence_->push_back( BuildKernelThunk(select, /*implements_whole_instruction=*/true)); @@ -1966,34 +2202,34 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { std::vector> thunks; - auto keys = sort->operand(0); - auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; - ShapeIndex keys_shape_index({}); - ShapeIndex values_shape_index({}); - if (values != nullptr) { - keys_shape_index = ShapeIndex({0}); - values_shape_index = ShapeIndex({1}); - } - auto keys_destination = GetAllocationSlice(*sort, keys_shape_index); - auto values_destination = GetAllocationSlice(*sort, values_shape_index); - - if (keys_destination != GetAllocationSlice(*keys)) { - thunks.push_back(absl::make_unique( - /*source_address=*/GetAllocationSlice(*keys), - /*destination_buffer=*/keys_destination, - /*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr)); - } - if (values != nullptr && values_destination != GetAllocationSlice(*values)) { - // TODO(b/26783907): Figure out why we never seem to share buffers for - // key/value sort. - thunks.push_back(absl::make_unique( - /*source_address=*/GetAllocationSlice(*values), - /*destination_buffer=*/values_destination, - /*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr)); + Shape keys_shape = sort->operand(0)->shape(); + for (int64 i = 0; i < sort->operand_count(); ++i) { + ShapeIndex shape_index = + sort->operand_count() > 1 ? 
ShapeIndex({i}) : ShapeIndex({}); + // We assume that the layout of all involved operands and outputs is the + // same. + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape, + sort->operand(i)->shape())); + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( + keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index))); + + // If possible, we share buffers. If that is not possible, we need to copy + // the values, because the emitter does the sorting in-place. + auto destination_buffer = GetAllocationSlice(*sort, shape_index); + auto source_address = GetAllocationSlice(*sort->operand(i)); + if (destination_buffer != source_address) { + // TODO(b/26783907): Figure out why we never seem to share buffers for + // key/value sort. + thunks.push_back(absl::make_unique( + /*source_address=*/source_address, + /*destination_buffer=*/destination_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()), + nullptr)); + } } int64 dimension_to_sort = sort->dimensions(0); - int64 dimension_to_sort_bound = keys->shape().dimensions(dimension_to_sort); + int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); auto index_type = b_.getInt64Ty(); @@ -2017,7 +2253,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { thunks.push_back( BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - keys->shape(), ir_emitter_context_->device_description()); + keys_shape, ir_emitter_context_->device_description()); UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), ir_emitter_context_->llvm_module()); @@ -2028,12 +2264,21 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { xor_mask = llvm::ConstantInt::get(index_type, 1LL << mask); } + IrArray keys_array; + std::vector values_arrays; + values_arrays.reserve(sort->operand_count() - 1); + for (int64 i = 0; i < sort->operand_count(); 
++i) { + ShapeIndex shape_index = + sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); + if (i == 0) { + keys_array = GetIrArray(*sort, *sort, shape_index); + } else { + values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); + } + } TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace( - dimension_to_sort, GetIrArray(*sort, *sort, keys_shape_index), - values != nullptr ? absl::make_optional( - GetIrArray(*sort, *sort, values_shape_index)) - : absl::nullopt, - IrName(sort), xor_mask, &b_, &launch_dimensions)); + dimension_to_sort, keys_array, values_arrays, IrName(sort), xor_mask, + &b_, &launch_dimensions)); } } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index bd5db7205155dc6b15ddea069e172bbd8f419996..93f11c069a4cebdf3c79cba17c824eded4f4b1db 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -76,6 +76,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleInfeed(HloInstruction* xla_infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleRng(HloInstruction* random) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; @@ -184,6 +185,14 @@ class IrEmitterUnnested : public IrEmitter { absl::Span> extra_output_gens); + // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in + // the process. `scatter` may be fused; scatter indices are taken from + // `scatter_indices_gen`, updates from `updates_gen`. The output buffer is + // expected to have the operand values in it already. 
+ Status EmitScatter(Thunk* thunk, HloInstruction* scatter, + const llvm_ir::ElementGenerator& scatter_indices_gen, + const llvm_ir::ElementGenerator& updates_gen); + // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel // for the hlo instruction. bool CheckAndEmitHloWithTile021(HloInstruction* hlo); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index c21f76f6eb1874bfa5a1d296c78ea0e3b9261eca..835924024b7b7de79624a369a69b07d72ac751ab 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -101,7 +101,7 @@ bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, HloInstruction* instr2) { - tensorflow::gtl::FlatSet in_list; + absl::flat_hash_set in_list; for (auto instr : instr1->operands()) { if (!IsProfitableOperand(instr)) { continue; @@ -148,7 +148,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { bool changed = false; RecomputeReachability(); - tensorflow::gtl::FlatSet to_fuse; + absl::flat_hash_set to_fuse; // Keep a list of the instructions to fuse after making all the fusion // decisions. 
We first aggressively add instructions to potential_fusion_list, // then filter out instructions that will be no longer fusible because of diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 0b3b429710a1a3158ce57a393a09291c95a2ef7a..791d414c915e6f23d84a38ae99dcfa9a59ab6353 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -38,9 +38,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" @@ -54,8 +56,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" -#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" -#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" @@ -75,7 +75,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" -#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -176,8 +175,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // elimination has to come after that pass. pipeline.AddPass(); - pipeline.AddPass(); - pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); @@ -204,21 +201,22 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { // Convert convolutions into CustomCalls to cudnn, then canonicalize them - // (PadInsertion). + // (CudnnConvPaddingLegalization). 
HloPassPipeline pipeline("conv_canonicalization"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); if (IsVoltaOrLater(*stream_exec)) { - pipeline.AddPass(); - // PadForTensorCores leaves behind unnecessary tuple/get-tuple-element - // pairs that TupleSimplifier fixes. + pipeline.AddPass(); + // CudnnConvPadForTensorCores leaves behind unnecessary + // tuple/get-tuple-element pairs that TupleSimplifier fixes. pipeline.AddPass(); } - // CudnnConvolutionRewriter, PadInsertion and PadForTensorCores may add - // instructions which can be simplified by constant folding. + // CudnnConvRewriter, CudnnConvPaddingLegalization and + // CudnnConvPadForTensorCores may add instructions which can be simplified + // by constant folding. pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -232,14 +230,17 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // a layout-sensitive verifier! HloPassPipeline pipeline("layout assignment"); pipeline.AddPass( - hlo_module->mutable_entry_computation_layout(), stream_exec); + hlo_module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout, stream_exec); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } { HloPassPipeline pipeline("post-layout_assignment"); - pipeline.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + pipeline.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. 
@@ -252,7 +253,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // Choose the fastest algorithm for each conv. // // We pick the algorithm before fusion so we can generate better HLO. After - // CudnnConvolutionRewriter, our convolutions are CustomCalls which return a + // CudnnConvRewriter, our convolutions are CustomCalls which return a // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of // scratch: // @@ -270,12 +271,12 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // The new tuple and gte instructions then be simplified away, because // nobody is expected to use the scratch value. // - // However, if we were to run CudnnConvolutionAlgorithmPicker after fusion + // However, if we were to run CudnnConvAlgorithmPicker after fusion // the gte(customcall, 0) would probably already be into a fusion node. We // can't simplify across HloComputation boundaries, so in this case we // wouldn't be able to simplify away the new_tuple bits. - pipeline.AddPass( - stream_exec, device_allocator, compiler); + pipeline.AddPass(stream_exec, device_allocator, + compiler); // Clean up new_tuple described above. 
pipeline.AddPass(); @@ -285,8 +286,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { HloPassFix fusion("fusion"); - fusion.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + fusion.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); @@ -298,7 +301,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, HloPassPipeline reduce_pipeline("reduce-precision"); reduce_pipeline.AddInvariantChecker( - /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false); + /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -324,8 +328,10 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + pipeline.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -400,7 +406,7 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { "prefers >= 9.2.88). 
Compilation of XLA kernels below will likely " "fail.\n\nYou do not need to update CUDA; cherry-picking the ptxas " "binary is sufficient."; - } else if ((vmaj < 9 || vmin < 2 || vdot < 88)) { + } else if (std::make_tuple(vmaj, vmin, vdot) < std::make_tuple(9, 2, 88)) { LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "." << vdot @@ -819,9 +825,8 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, } StatusOr>> -NVPTXCompiler::CompileAheadOfTime( - std::vector> module, - const AotCompilationOptions& options) { +NVPTXCompiler::CompileAheadOfTime(std::unique_ptr module_group, + const AotCompilationOptions& options) { return Unimplemented( "not yet implemented: NVPTXCompiler::CompileAheadOfTime"); } diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index 8e97774750344bfc141daa7d752300762c708613..f79ae2990ae7d6e6985b15727a72358289121aa9 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/node_hash_map.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/executable.h" @@ -58,7 +59,7 @@ class NVPTXCompiler : public LLVMCompiler { DeviceMemoryAllocator* device_allocator) override; StatusOr>> - CompileAheadOfTime(std::vector> module, + CompileAheadOfTime(std::unique_ptr module_group, AotCompilationOptions const& options) override; se::Platform::Id PlatformId() const override; @@ -140,10 +141,10 @@ class NVPTXCompiler : public LLVMCompiler { tensorflow::condition_variable compilation_done_cv_; }; - // Don't even think about switching this to FlatMap; iterator stability is - // critical here. - std::unordered_map + // Don't even think about switching this to flat_hash_map; iterator stability + // is critical here. 
+ absl::node_hash_map compilation_cache_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(NVPTXCompiler); diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index cf9f102d31305da15dabaf6247f23c5ca9a9e054..375f68a15957936151aee068582a714b62694af2 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -62,13 +62,8 @@ LaunchDimensions CalculateLaunchDimensions( // // * = - auto threads_per_core = device_desc.threads_per_core_limit(); - auto blocks_per_core = device_desc.blocks_per_core_limit(); - int64 threads_per_block; - if (threads_per_core != 0 && blocks_per_core != 0) { - threads_per_block = device_desc.threads_per_core_limit() / - device_desc.blocks_per_core_limit(); - } else { + int64 threads_per_block = device_desc.threads_per_block_limit(); + if (threads_per_block == 0) { static std::atomic log_count{0}; if (log_count.fetch_add(1) < 8) { LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.h b/tensorflow/compiler/xla/service/gpu/stream_assignment.h index c2df83aaa4347a9439798acc6cfc2ba0db995232..52d38b6f20e8d61e2d4966ad15a5583a9cd2e945 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace gpu { @@ -34,7 +34,7 @@ class StreamAssignment { private: int stream_count_ = 1; // At least the main stream. 
- tensorflow::gtl::FlatMap hlo_to_stream_number_; + absl::flat_hash_map hlo_to_stream_number_; }; // Assigns GPU streams to instructions in `module`. diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index a7255335672a3622d122e9fc5ebfab236a5ba895..ed46f08d5970d479db33a7b9ad416a1480535764 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -211,15 +211,13 @@ tf_cc_test( ) tf_cc_test( - name = "cudnn_fused_convolution_rewriter_test", - srcs = ["cudnn_fused_convolution_rewriter_test.cc"], + name = "gpu_atomic_test", + srcs = ["gpu_atomic_test.cc"], tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b18c4c63714b4b3c06d7fa85f4a7a75b8e9ae12 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuAtomicTest : public GpuCodegenTest {}; + +TEST_F(GpuAtomicTest, TestStore) { + const char* hlo_string = R"( + HloModule TensorFlowScatterV1 + + update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) + } + + ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + } +)"; + + CompileAndVerifyIr(hlo_string, R"( +CHECK: store atomic{{.*}}unordered, align 4 +)"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 2bd04259c0e8193e6fde415df17a8232c701dec4..9220865867b770eebfb1ada8f31a5d24693a4b8d 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -18,14 +18,16 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/util.h" namespace xla { -using tensorflow::gtl::FlatMap; -using tensorflow::gtl::FlatSet; +using absl::flat_hash_map; +using absl::flat_hash_set; /*static*/ StatusOr HeapSimulator::MinimumMemoryForModule( @@ -56,7 +58,7 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation) { TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, @@ -88,7 +90,7 @@ StatusOr HeapSimulator::Run( const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options, - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, /*schedule=*/nullptr, memory_by_computation); @@ -115,8 +117,10 @@ Status HeapSimulator::RunComputation( // 'used_buffers' is the reverse map - it tracks which buffers were used by an // instruction, so that we can remove the instructions from a buffer's live // set after they are visited. 
- FlatMap> live_buffers; - FlatMap> used_buffers; + flat_hash_map> + live_buffers; + flat_hash_map> + used_buffers; auto add_user_to_buffer = [this, &live_buffers, &used_buffers]( const HloInstruction* user, const BufferValue* buffer) { @@ -213,7 +217,7 @@ Status HeapSimulator::RunComputation( VLOG(4) << " Removing user " << instruction->name() << " from buffer " << operand_buffer->ToString(); auto it = live_buffers.find(operand_buffer); - FlatSet* live_set = &it->second; + flat_hash_set* live_set = &it->second; live_set->erase(instruction); if (live_set->empty()) { live_buffers.erase(it); @@ -235,7 +239,8 @@ Status HeapSimulator::RunComputation( // that we should assign. // Make sure each buffer get reused at most once. - FlatSet reused_buffers; + flat_hash_set reused_buffers; + int64 alloc_size_by_instruction = 0; for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; @@ -268,14 +273,15 @@ Status HeapSimulator::RunComputation( if (!shared) { VLOG(3) << " Allocating: " << buffer->ToString(); + alloc_size_by_instruction += size_fn_(*buffer); Alloc(buffer, instruction); } } // Account for the memory used by subcomputations when estimating the // current heap size. 
if (memory_by_computation_ != nullptr) { - algorithm_->AccountForSubcomputationMemory(instruction, - *memory_by_computation_); + algorithm_->AccountForSubcomputationMemory( + instruction, alloc_size_by_instruction, *memory_by_computation_); } // If all computations in the module have been scheduled, we can save memory @@ -323,7 +329,7 @@ Status HeapSimulator::RunComputation( to_free.reserve(live_buffers.size()); for (const auto& buffer_pending : live_buffers) { const BufferValue* buffer = buffer_pending.first; - const FlatSet& pending = buffer_pending.second; + const flat_hash_set& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; to_free.push_back(buffer); @@ -345,7 +351,7 @@ HeapSimulator::HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, const HloSchedule* schedule, - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation) : no_fragmentation_stats_(absl::make_unique()), algorithm_(std::move(algorithm)), @@ -381,10 +387,8 @@ void HeapSimulator::Alloc(const BufferValue* buffer, allocated_buffers_.insert(buffer); const int64 size = size_fn_(*buffer); - const HloInstruction* instruction_to_calc_aliasing = - memory_by_computation_ == nullptr ? nullptr : instruction; - algorithm_->Alloc(buffer, size, instruction_to_calc_aliasing); - no_fragmentation_stats_->Alloc(buffer, size, instruction_to_calc_aliasing); + algorithm_->Alloc(buffer, size); + no_fragmentation_stats_->Alloc(buffer, size); FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction, nullptr); } @@ -522,21 +526,9 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { } } -void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size, - const HloInstruction* instruction) { - // The output buffer of while/call/conditional is always aliased with the - // output buffer of the root instruction in the body. 
Don't double count. - if (instruction == nullptr || - (instruction->opcode() != HloOpcode::kWhile && - instruction->opcode() != HloOpcode::kCall && - instruction->opcode() != HloOpcode::kConditional)) { - Alloc(buffer, size); - } -} - void NoFragmentationStatsHeap::AccountForSubcomputationMemory( - const HloInstruction* instruction, - const tensorflow::gtl::FlatMap& + const HloInstruction* instruction, int64 alloc_size_by_instruction, + const absl::flat_hash_map& memory_by_computation) { // We only count the memory usage of the largest subcomputation, instead of // adding them all, because subcomputations won't execute in parallel. @@ -550,6 +542,14 @@ void NoFragmentationStatsHeap::AccountForSubcomputationMemory( } } } + if (max_subcomputation_bytes > 0 && + (instruction->opcode() == HloOpcode::kWhile || + instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kConditional)) { + // The output buffer of while/call/conditional is always aliased with the + // output buffer of the root instruction in the body. Don't double count. + max_subcomputation_bytes -= alloc_size_by_instruction; + } max_heap_size_ = std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes); } diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 7d6dcc0dc9436ea6bd30ae14ffe226c014f1ca68..dbbf43082f2c1d21f5ef42f53804bf0969903a58 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -30,8 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -58,7 +58,7 @@ class HeapSimulator { // Result represents the result of the heap simulation. struct Result { // The assignment of buffers to chunks. - tensorflow::gtl::FlatMap chunk_map; + absl::flat_hash_map chunk_map; // The total size in bytes of the heap, containing all assigned chunks. int64 heap_size = 0; @@ -100,7 +100,7 @@ class HeapSimulator { const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation = nullptr); // Run the heap simulation with the given algorithm, assuming the given @@ -130,7 +130,7 @@ class HeapSimulator { const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options = Options(), - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation = nullptr); private: @@ -140,7 +140,7 @@ class HeapSimulator { HeapSimulator(std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, const HloSchedule* schedule = nullptr, - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation = nullptr); ~HeapSimulator(); @@ -172,7 +172,7 @@ class HeapSimulator { // handle subcomputations. It would be good to unify the handling of // subcomputations, but it's not clear how. 
const HloSchedule* schedule_; - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation_; // In addition to Alloc and Free, the heap simulator exposes a concept of @@ -193,12 +193,12 @@ class HeapSimulator { const BufferValue* canonical = nullptr; int64 refcount = 0; }; - tensorflow::gtl::FlatMap> + absl::flat_hash_map> shared_buffers_; // Hold some sets for error-checking the sequence of Alloc and Free calls. - tensorflow::gtl::FlatSet allocated_buffers_; - tensorflow::gtl::FlatSet freed_buffers_; + absl::flat_hash_set allocated_buffers_; + absl::flat_hash_set freed_buffers_; // Debugging information filled in while the heap simulator runs. HeapSimulatorTrace debug_trace_; @@ -218,12 +218,6 @@ class HeapAlgorithm { // Alloc allocates a buffer of 'size' bytes. virtual void Alloc(const BufferValue* buffer, int64 size) = 0; - // NoFragmentationStatsHeap overrides this method. - virtual void Alloc(const BufferValue* buffer, int64 size, - const HloInstruction* instruction) { - Alloc(buffer, size); - } - // Takes memory usage of subcomputations into account when calculating the // memory usage of a computation. Currently, we don't handle buffer aliasing // between computations entirely correctly. We are careful to not double count @@ -235,7 +229,9 @@ class HeapAlgorithm { // analysis, it's not worth making major changes to HeapSimulator now. virtual void AccountForSubcomputationMemory( const HloInstruction* instruction, - const tensorflow::gtl::FlatMap& + // The total number of bytes allocated by instruction. + int64 alloc_size_by_instruction, + const absl::flat_hash_map& memory_by_computation) {} // Free de-allocates a previously allocated buffer. 
@@ -257,12 +253,9 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { void Alloc(const BufferValue* buffer, int64 size) override; - void Alloc(const BufferValue* buffer, int64 size, - const HloInstruction* instruction) override; - void AccountForSubcomputationMemory( - const HloInstruction* instruction, - const tensorflow::gtl::FlatMap& + const HloInstruction* instruction, int64 alloc_size_by_instruction, + const absl::flat_hash_map& memory_by_computation) override; void Free(const BufferValue* buffer, int64 size) override; @@ -382,8 +375,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { // Free time of the buffer. int64 end; }; - tensorflow::gtl::FlatMap - buffer_intervals_; + absl::flat_hash_map buffer_intervals_; }; // A heap algorithm that chooses the best results from other algorithms added to diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 191fbf8194ac65684cd7bfd48a6931d82c702186..e30e7667f3015bc7bfe67c65147a5016332780f7 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" @@ -31,7 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { @@ -98,6 +98,124 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie()); } +TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) { + // HloModule SubcomputationAccounting + + // %WhileBody (body_param: f32[4]) -> f32[4] { + // %body_param = f32[4]{0} parameter(0) + // %constant.1 = f32[4]{0} constant({1, 1, 1, 1}) + // ROOT %subtract = f32[4]{0} subtract(f32[4]{0} %body_param, f32[4]{0} + // %constant.1) + // } + + // %WhileCond (cond_param: f32[4]) -> pred[] { + // %cond_param = f32[4]{0} parameter(0) + // %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]} + // %reshape = f32[] reshape(f32[1]{0} %slice) + // %constant = f32[] constant(0) + // ROOT %not-equal-to = pred[] not-equal-to(f32[] %reshape, f32[] %constant) + // } + + // ENTRY %SubcomputationAccounting () -> f32[2,4] { + // %constant.3 = f32[2,4]{1,0} constant(f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, + // 3, 4 } }) %transpose = f32[2,4]{1,0} transpose(f32[2,4]{1,0} + // %constant.3), dimensions={0,1} %constant.2 = f32[4]{0} constant({1, 1, 1, + // 1}) %while = f32[4]{0} while(f32[4]{0} %constant.2), + // condition=%WhileCond, body=%WhileBody %broadcast = f32[2,4]{1,0} + // broadcast(f32[4]{0} %while), dimensions={1} ROOT %add = f32[2,4]{1,0} + // add(f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) + // } + + auto module = CreateNewVerifiedModule(); + const Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); + + // reshape(slice(param)) != 0 + // Needs 5 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = 
cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* slice = + cond_builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1}), cond_param, {0}, {1}, {1})); + HloInstruction* reshape = + cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice)); + HloInstruction* zero = cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction* cond_comparison = + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, reshape, zero)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 1, 1, 1}))); + HloInstruction* subtract = + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + // transpose(matrix) + bcast(while) + auto builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 1, 1, 1}))); + // Creates 16 bytes, ignoring subcomputations + HloInstruction* while_loop = + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + // Creates 32 bytes and frees 16 + HloInstruction* bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, while_loop, {1})); + + HloInstruction* matrix = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( + {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); + 
// Creates 32 bytes + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); + + // Creates 32 bytes and frees 64 + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); + + auto entry_computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + std::vector cond_vec = {cond_param, slice, reshape, zero, + cond_comparison}; + std::vector while_body_vec = {body_param, one_vector, + subtract}; + std::vector entry_comp_vec = {while_init, while_loop, bcast, + matrix, transpose, add}; + schedule.set_sequence(cond_computation, cond_vec); + schedule.set_sequence(body_computation, while_body_vec); + schedule.set_sequence(entry_computation, entry_comp_vec); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + absl::flat_hash_map memory_by_computation; + memory_by_computation[cond_computation] = 5; + memory_by_computation[body_computation] = 16; + std::unique_ptr points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator accounts for subcomputations. The output buffer is aliased, + // so we don't double count. + EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, schedule.sequence(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) + .ValueOrDie()); +} + const char kAlloc[] = "Alloc"; const char kFree[] = "Free"; const char kFinish[] = "Finish"; @@ -174,7 +292,7 @@ class HeapSimulatorTracker { // Construct the module sequence grouped by computation. 
HloSchedule schedule(module_.get()); - tensorflow::gtl::FlatMap reverse_position; + absl::flat_hash_map reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { const HloInstruction* instruction = full_module_sequence[i]; schedule.GetOrCreateSequence(instruction->parent()) diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index caaca16f7155f15ea2ac79268ba2e708968b6e33..dbab62f847e8ca5e0b46dfd4162a0f4222640252 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 54 +// Next ID: 58 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -180,6 +180,17 @@ message HloInstructionProto { // Collective permute field. repeated SourceTarget source_target_pairs = 52; + + // Sharding for kDomain instructions. + xla.OpSharding domain_entry_sharding = 54; + xla.OpSharding domain_exit_sharding = 55; + + // For custom call this indicates that the layouts are constrained. If + // constrain_layout is true then the 'shape' field must contain a layout, and + // 'operand_shapes_with_layout' must contain a shape with layout for each + // operand. + bool constrain_layout = 56; + repeated Shape operand_shapes_with_layout = 57; } // Serialization of HloComputation. @@ -214,6 +225,32 @@ message HloScheduleProto { map sequences = 1; } +message HloInputOutputAliasProto { + // The following proto describes a pair of aliased an input + // (described by parameter number and a ShapeIndex of the parameter) + // and an output (described by a ShapeIndex of the root + // instruction). 
For example: + // + // entry = { + // output_shape_index={1}, + // parameter_number=0, + // parameter_shape_index={1, 2}, + // } + // + // This entry indicates that the first parameter's {1, 2} element is + // aliased with the {1} element of the root instruction. + message AliasEntryProto { + // ShapeIndex of the root hlo. + repeated int64 output_shape_index = 1; + // Number of the parameter in entry computation. + int64 parameter_number = 2; + // ShapeIndex of the parameter instruction. + repeated int64 parameter_shape_index = 3; + } + + repeated AliasEntryProto entries = 1; +} + // Serialization of HloModule. message HloModuleProto { string name = 1; @@ -224,14 +261,17 @@ message HloModuleProto { // callees appear before their callers. repeated HloComputationProto computations = 3; - // The program shape (with layout) of the entry computation. - xla.ProgramShape program_shape = 4; + // The host program shape (with layout) of the entry computation. + xla.ProgramShape host_program_shape = 4; // The id of this module. int64 id = 5; // The schedule for this module. HloScheduleProto schedule = 7; + + // Describes alias information between inputs and outputs. + HloInputOutputAliasProto input_output_alias = 8; } // Serialization of LogicalBuffer. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 0986da65cbd3d550ecfa01212364518aba651d86..cf8e6594cbe5ffd28ca75dd5006e8817f1e8581c 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" @@ -57,8 +59,9 @@ class BufferValueMap { // construction process. 
using BufferNumber = int64; - explicit BufferValueMap(const HloDataflowAnalysis& dataflow) - : dataflow_(dataflow) { + explicit BufferValueMap(HloModule* module, + const HloDataflowAnalysis& dataflow) + : module_(module), dataflow_(dataflow) { buffers_.reserve(dataflow_.values().size()); value_to_buffer_number_.reserve(dataflow_.values().size()); for (const HloValue* value : dataflow_.values()) { @@ -119,7 +122,7 @@ class BufferValueMap { } // Return a set of all the values in the given buffer. - const tensorflow::gtl::FlatSet& GetValuesInBuffer( + const absl::flat_hash_set& GetValuesInBuffer( BufferNumber buffer_number) const { return buffers_.at(buffer_number); } @@ -142,7 +145,7 @@ class BufferValueMap { // Move the given value into the given buffer. void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) { BufferNumber old_buffer_number = value_to_buffer_number_.at(&value); - tensorflow::gtl::FlatSet& old_value_set = + absl::flat_hash_set& old_value_set = buffers_.at(old_buffer_number); old_value_set.erase(&value); if (old_value_set.empty()) { @@ -169,6 +172,42 @@ class BufferValueMap { return value_to_buffer_number_.at(&value); } + void ComputeInputOutputAliasedBuffers( + const HloValue& value, std::vector* aliased_buffers) { + // Get parameter value from an aliased_input object. + const auto get_parameter_value = + [this](const std::pair& aliased_input) + -> const HloValue& { + int64 param_number = aliased_input.first; + const ShapeIndex& param_index = aliased_input.second; + return dataflow_.GetUniqueValueAt( + module_->entry_computation()->parameter_instruction(param_number), + param_index); + }; + + // If the value shows up in a root instruction, alias it with parameter + // instruction. 
+ for (const HloPosition& pos : value.positions()) { + if (pos.instruction == module_->entry_computation()->root_instruction()) { + ShapeIndex output_index = pos.index; + + auto aliased_input = + module_->input_output_alias_config().GetAliasedParameter( + output_index); + if (aliased_input) { + aliased_buffers->push_back( + GetBufferForValue(get_parameter_value(*aliased_input))); + } + } + } + + // If the value is parameter instruction itself, alias it with itself. + if (value.instruction()->opcode() == HloOpcode::kParameter && + value.instruction()->parent() == module_->entry_computation()) { + aliased_buffers->push_back(GetBufferForValue(value)); + } + } + void ComputeWhileAliasedBuffers(const HloValue& value, std::vector* aliased_buffers) { VLOG(3) << "Compute kWhile aliases"; @@ -276,6 +315,7 @@ class BufferValueMap { VLOG(2) << "Use of value " << value.ToShortString() << ": " << use; } std::vector aliased_buffers; + ComputeInputOutputAliasedBuffers(value, &aliased_buffers); ComputeWhileAliasedBuffers(value, &aliased_buffers); ComputeConditionalAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. @@ -286,17 +326,17 @@ class BufferValueMap { return aliased_buffers; } + HloModule* module_; + // Dataflow analysis used to construct the buffer map. const HloDataflowAnalysis& dataflow_; // A map containing the set of values contained in each buffer. - tensorflow::gtl::FlatMap> + absl::flat_hash_map> buffers_; // A map indicating which buffer each value is contained in. - tensorflow::gtl::FlatMap - value_to_buffer_number_; + absl::flat_hash_map value_to_buffer_number_; // The buffer number of the next buffer to be created. 
BufferNumber next_buffer_number_ = 0; @@ -352,7 +392,7 @@ bool HloAliasAnalysis::InstructionBuffersAreAmbiguous( bool HloAliasAnalysis::InstructionBuffersAreDistinct( const HloInstruction* instruction) const { - tensorflow::gtl::FlatSet buffers_seen; + absl::flat_hash_set buffers_seen; for (const auto& pair : dataflow_analysis_->GetInstructionValueSet(instruction)) { const HloValueSet& value_set = pair.second; @@ -461,7 +501,7 @@ StatusOr> HloAliasAnalysis::Run( /*bitcast_defines_value=*/false, fusion_can_share_buffer)); - BufferValueMap buffer_map(alias_analysis->dataflow_analysis()); + BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis()); buffer_map.MergeAliasedBuffers(); // Create a vector of HloBuffers, one for each set of values in the diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index e345804537723f01e9ccb63e7d6ded1bd68f4196..372f99ff01c786a503e9fc2a1ba96fb4abf75b4c 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" @@ -110,7 +111,7 @@ class HloAliasAnalysis { std::unique_ptr dataflow_analysis_; // A map indicating which buffer a value is contained in. - tensorflow::gtl::FlatMap value_to_buffer_; + absl::flat_hash_map value_to_buffer_; // A lazily constructed vector containing all HloBuffers sorted by // HloBuffer::Id. 
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 0cd0ab36fcf832af9a71ab5837c94f9b39bc4bf3..5c8d97b2d15e15d15cb8014a7d25b37437ce8aec 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -217,6 +217,181 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) { EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } +TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) { + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + + auto negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); + module_->AddEntryComputation(builder.Build()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + // Cannot alias an output twice. 
+ ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); +} + +TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) { + // parameter 0 aliased with output 1 and parameter 1 aliased with output 0. + // + // (p0 , p1) + // \ / + // \ / + // alias X + // / \ + // / \ + // (p0 , p1) + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module_->AddEntryComputation(builder.Build()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + + // Cannot alias an output twice. + ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + // Every Ops in this graph are aliased with each other. 
+ EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); +} + +TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) { + // Test a simple single while instruction can be aliased with input and output + // of the computation. + // + // body((F32[], F32[]) %tuple_param): + // %add = Add(%tuple_param{0}, %tuple_param{1}) + // return Tuple(%tuple_param{0}, %add) + // + // condition((F32[], F32[]) %tuple_param): + // return Constant(false) + // + // entry: + // %param1 = param1 + // %while = While(%param1, body, condition) + // %while_1 = GTE(%while, 0) + // %while_2 = GTE(%while, 1) + // %negate_1 = Negate(%while_1) + // %negate_2 = Negate(%while_2) + // return Tuple(negate_1, negate_2) + // + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Element 0 passes transparently through the body. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); + auto body_tuple = body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_0, add})); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); + + // Condition computation trivially returns a constant "false". 
+ auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, param)); + auto while_element_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 0)); + auto while_element_2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 1)); + auto negate_1 = builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, while_element_1)); + auto negate_2 = builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, while_element_2)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2})); + module_->AddEntryComputation(builder.Build()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_THAT( + GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})), + UnorderedElementsAre(GetValueDefinedAt(param, {1}), + GetValueDefinedAt(xla_while, /*index=*/{1}), + GetValueDefinedAt(body_param, {1}), + GetValueDefinedAt(cond_param, {1}), + GetValueDefinedAt(add), + GetValueDefinedAt(negate_2))); + + EXPECT_THAT( + analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(), + 
UnorderedElementsAre( + HloPosition{param, {1}}, HloPosition{xla_while, {1}}, + HloPosition{while_element_2, {}}, HloPosition{body_param, {1}}, + HloPosition{body_element_1, {}}, HloPosition{add, {}}, + HloPosition{body_tuple, {1}}, HloPosition{tuple, {1}}, + HloPosition{cond_param, {1}}, HloPosition{negate_2, {}})); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); +} + TEST_F(HloAliasAnalysisTest, SingleCall) { // Test a single call of a subcomputation. The subcomputation adds its two // array-shaped parameters. diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index 6c11a073b74c61e44dfe81a32261ae78ae7b46fb..9c3aa0e64d119c2560f4955d0bcb492519fa52a2 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_clone_context.h b/tensorflow/compiler/xla/service/hlo_clone_context.h index 658643b427a9625fac1166151a89cbd669f817d5..24910ca07bf7c991d31875704b5dd918ed04fe6f 100644 --- a/tensorflow/compiler/xla/service/hlo_clone_context.h +++ b/tensorflow/compiler/xla/service/hlo_clone_context.h @@ -18,8 +18,8 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -73,12 +73,12 @@ class HloCloneContext { return FindOrDie(computations_, old_computation); } - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& cloned_instructions() const { return instructions_; } - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& cloned_computations() const { return computations_; } @@ -86,10 +86,8 @@ class HloCloneContext { private: HloModule* module_; string suffix_; - tensorflow::gtl::FlatMap - instructions_; - tensorflow::gtl::FlatMap - computations_; + absl::flat_hash_map instructions_; + absl::flat_hash_map computations_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 0e5920af7a60966ace4ff52662cd23ea3141d477..b0f7cd91ad1db0a59c09cfbfc1885813dc57e01e 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -24,6 +24,8 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -39,7 +41,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -122,30 +123,6 @@ HloInstruction* HloComputation::AddParameter( return instructions_.back().get(); } -namespace { - -// Returns the new name for a fusion parameter when we change its number. -// -// Fusion parameters are named foo.param_1, bar.param_2, etc. 
We are -// renumbering the parameters, so replace the final number in the name with -// the updated value. -string RenameFusionParameter(const string& original_name, int64 new_param_no) { - const string param_underscore = ".param_"; - size_t index = original_name.rfind(param_underscore); - if (index == string::npos) { - return original_name; - } - string after_param = original_name.substr(index + param_underscore.size()); - int64 numeric_suffix; - if (absl::SimpleAtoi(after_param, &numeric_suffix)) { - return StrCat(original_name.substr(0, index + param_underscore.size()), - new_param_no); - } - return original_name; -} - -} // namespace - Status HloComputation::RemoveParameter(int64 param_no) { CHECK_GE(param_no, 0); CHECK_LT(param_no, param_instructions_.size()); @@ -158,11 +135,9 @@ Status HloComputation::RemoveParameter(int64 param_no) { while (param_no < param_instructions_.size()) { param_instruction = param_instructions_[param_no]; - string param_name = - RenameFusionParameter(param_instruction->name(), param_no); HloInstruction* new_instr = AddInstructionInternal(HloInstruction::CreateParameter( - param_no, param_instruction->shape(), param_name)); + param_no, param_instruction->shape(), StrCat("param_", param_no))); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); param_instructions_[param_no] = new_instr; TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); @@ -186,11 +161,9 @@ Status HloComputation::RemoveUnusedParameters() { if (removed > 0) { const int64 param_no = i - removed; - string param_name = - RenameFusionParameter(param_instruction->name(), param_no); - HloInstruction* new_instr = - AddInstructionInternal(HloInstruction::CreateParameter( - param_no, param_instruction->shape(), param_name)); + HloInstruction* new_instr = AddInstructionInternal( + HloInstruction::CreateParameter(param_no, param_instruction->shape(), + StrCat("param_", param_no))); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); 
param_instructions_[param_no] = new_instr; TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); @@ -242,7 +215,7 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( if (removed.count(item) != 0 || item->user_count() != 0 || item == root_instruction() || !IsRemovable(item) || - item->HasSideEffect()) { + (item->HasSideEffect() && item != instruction)) { continue; } for (int i = 0; i < item->operand_count(); ++i) { @@ -305,10 +278,9 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, namespace { // Helper which builds a post order of the HLO call graph. -void ComputeComputationPostOrder( - HloComputation* computation, - tensorflow::gtl::FlatSet* visited, - std::vector* post_order) { +void ComputeComputationPostOrder(HloComputation* computation, + absl::flat_hash_set* visited, + std::vector* post_order) { if (visited->insert(computation).second) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -325,7 +297,7 @@ void ComputeComputationPostOrder( void HloComputation::ComputeInstructionPostOrder( const HloComputation::ChannelDependencyMap& channel_dependency_map, std::vector* post_order, HloInstruction* root, - tensorflow::gtl::FlatMap* visited) const { + absl::flat_hash_map* visited) const { std::vector dfs_stack; dfs_stack.push_back(root); while (!dfs_stack.empty()) { @@ -422,7 +394,7 @@ std::vector HloComputation::MakeInstructionPostOrder() const { std::vector post_order; post_order.reserve(instruction_count()); std::vector trace_instructions; - tensorflow::gtl::FlatMap visited; + absl::flat_hash_map visited; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. 
Add trace @@ -443,7 +415,7 @@ std::vector HloComputation::MakeInstructionPostOrder() const { std::vector HloComputation::MakeEmbeddedComputationsList() const { - tensorflow::gtl::FlatSet visited; + absl::flat_hash_set visited; std::vector post_order; // To avoid special handling of this computation, cast away const of @@ -533,9 +505,9 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr> HloComputation::CreateFromProto( const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map) { - tensorflow::gtl::FlatMap instruction_map; - tensorflow::gtl::FlatMap to_proto_id; + const absl::flat_hash_map& computation_map) { + absl::flat_hash_map instruction_map; + absl::flat_hash_map to_proto_id; std::vector> instructions; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { @@ -563,6 +535,28 @@ HloComputation::CreateFromProto( return to_proto_id[a.get()] < to_proto_id[b.get()]; }); + TF_RETURN_IF_ERROR([&]() -> Status { + std::vector parameters_seen(parameter_count); + int parameters_seen_count = 0; + for (auto& instruction : instructions) { + if (instruction->opcode() == HloOpcode::kParameter) { + int64 param_no = instruction->parameter_number(); + TF_RET_CHECK(param_no >= 0 && param_no < parameter_count) + << "Invalid parameter number. 
Expected [0, " << parameter_count + << "), got " << param_no; + TF_RET_CHECK(!parameters_seen[param_no]) + << "Parameter number " << param_no + << " already allocated in this computation"; + parameters_seen[param_no] = true; + parameters_seen_count++; + } + } + TF_RET_CHECK(parameters_seen_count == parameter_count) + << "Not all parameters in range [0, " << parameter_count + << ") were referenced"; + return Status::OK(); + }()); + auto computation = absl::WrapUnique( new HloComputation(proto.name(), parameter_count, &instructions, root, /*fusion_instruction=*/nullptr)); diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 936a53bd7e9ad362d10f06ab807ddb8944fec93e..dec96d11a93cf56d3c40a6bb7882ffb7336aeeb0 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -25,6 +25,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/map_util.h" @@ -40,8 +42,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -128,9 +128,10 @@ class HloComputation { // users. Instruction is deallocated with this call. Status RemoveInstruction(HloInstruction* instruction); - // Remove an instruction from the computation and also transitively any - // operand that has no users post removing an instruction. The instruction - // must have no users. Instruction is deallocated with this call. 
+ // Remove an instruction (including side effecting ones) from the computation + // and also transitively any operand that has no side effect and no users post + // removing an instruction. The instruction must have no users. Instruction is + // deallocated with this call. Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction); // Set the root of the computation to the given instruction. The instruction @@ -188,7 +189,7 @@ class HloComputation { // calls. static StatusOr> CreateFromProto( const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map); + const absl::flat_hash_map& computation_map); // Gets the instructions in this computation. // @@ -414,14 +415,14 @@ class HloComputation { // cross-replica-sum the union of the dependencies for all participating // instructions. using ChannelDependencyMap = - tensorflow::gtl::FlatMap>; + absl::flat_hash_map>; ChannelDependencyMap ComputeChannelDependencies() const; enum VisitState { kVisiting, kVisited }; void ComputeInstructionPostOrder( const HloComputation::ChannelDependencyMap& channel_dependency_map, std::vector* post_order, HloInstruction* root, - tensorflow::gtl::FlatMap* visited) const; + absl::flat_hash_map* visited) const; string name_; int64 unique_id_; @@ -439,7 +440,7 @@ class HloComputation { // instruction pointer to location in the list for fast lookup. 
using InstructionList = std::list>; InstructionList instructions_; - tensorflow::gtl::FlatMap + absl::flat_hash_map instruction_iterators_; std::vector param_instructions_; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index f837816cea78d78bb3d605dd91e81cac39036268..4f898ce61c3f36e83e4b13130a404dbb4a2c36c6 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -76,6 +76,26 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } + // Don't constant fold unless it's a net positive or the output is small. + if (ShapeUtil::IsArray(instruction->shape())) { + int64 elements_in_removed_operands = 0; + for (HloInstruction* operand : instruction->operands()) { + if (operand->user_count() == 1 && + ShapeUtil::IsArray(operand->shape())) { + elements_in_removed_operands += + ShapeUtil::ElementsIn(operand->shape()); + } + } + int64 elements_in_constant = + ShapeUtil::ElementsIn(instruction->shape()); + + static const int64 kMaximumConstantSizeElements = 2 * 1000 * 1000; + if (elements_in_constant > elements_in_removed_operands && + elements_in_constant > kMaximumConstantSizeElements) { + continue; + } + } + Literal result; // Currently we skip unimplemented operations. // TODO(b/35975797): Fold constant computations for more operations. 
@@ -84,6 +104,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { << instruction->ToString(); continue; } + VLOG(4) << "Constant folded: " << instruction->ToString(); TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( instruction, HloInstruction::CreateConstant(std::move(result)))); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 3e0def5d26a0033d954a776c1c32d6c35acfb505..e45f905f7152c37a9ab2b41d407310671310c2a3 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -242,5 +242,25 @@ TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce()); } +const char* const kConstantFoldLargePad = R"( + HloModule ConstantFoldLargePad + + ENTRY r { + a = f32[1,1,1] constant(f32[1,1,1]{{{7}}}) + b = f32[] constant(42) + ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63 + })"; + +TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kConstantFoldLargePad)); + HloConstantFolding const_folder; + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + EXPECT_FALSE(result); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Pad(op::Constant(), op::Constant())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index a502fff9a0f1e40065746f2193bf76b1adefdb31..23ab4cda93fc5d6979308bdf9a87f0a16d465154 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -664,6 +664,11 @@ Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { } Status HloCostAnalysis::HandleGather(const 
HloInstruction* gather) { + // Gather doesn't read the whole input buffer, it's equivalent to a copy the + // size of the output shape and a read of the gather indices. + current_properties_[kBytesAccessedKey] = + GetShapeSize(gather->shape()) * 2 + + GetShapeSize(gather->operand(1)->shape()); // Gather does not issue any flops. return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index d76ce9ecbca67ae3bc3db4ee2452f30ccec5b88b..802cdfc9e454cf05db18fad9bc7f44fdc146a92e 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -556,5 +556,30 @@ TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { EXPECT_EQ(analysis.bytes_accessed(), 8); } +TEST_F(HloCostAnalysisTest, Gather) { + // Test the analysis on a gather. + XlaBuilder builder("gather"); + Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3}); + Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); + + auto operand = Parameter(&builder, 0, operand_shape, "operand"); + auto indices = Parameter(&builder, 1, indices_shape, "indices"); + GatherDimensionNumbers dim_numbers; + dim_numbers.add_offset_dims(1); + dim_numbers.add_collapsed_slice_dims(0); + dim_numbers.add_start_index_map(0); + dim_numbers.set_index_vector_dim(1); + Gather(operand, indices, dim_numbers, {1, 3}); + + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. 
+ HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 56); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index b59c9ba3ed7990eb2a35abc83f87b25a1b1e7c60..e602107cbe64320a8e8e740168cb294ec6be9667 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" namespace xla { @@ -137,8 +137,8 @@ StatusOr HloCSE::Run(HloModule* module) { // HLO instructions are grouped into equivalency classes by using the // cse_equal predicate defined above. This set holds a representative // instruction for each class. - tensorflow::gtl::FlatSet + absl::flat_hash_set representatives(/*N=*/computation->instruction_count() + 1, &CseHash, cse_equal); for (auto instruction : computation->MakeInstructionPostOrder()) { diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 6a63681996bc57f4ef16b2405ffc8ce4f003e783..5dcf6bc985ff18fa6fc1ab5a5692914b4597d065 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -91,7 +92,7 @@ HloDataflowAnalysis::HloDataflowAnalysis( bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( const HloInstruction* inst) { - tensorflow::gtl::FlatSet visited; + absl::flat_hash_set visited; absl::InlinedVector stack; stack.push_back(inst); while (!stack.empty()) { @@ -125,7 +126,7 @@ bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, const HloValue& HloDataflowAnalysis::GetValueDefinedAt( const HloInstruction* instruction, const ShapeIndex& index) const { - CHECK(ValueIsDefinedAt(instruction, index)); + CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString(); return GetUniqueValueAt(instruction, index); } @@ -159,8 +160,8 @@ void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) { void HloDataflowAnalysis::DeleteMarkedValues() { #ifndef NDEBUG // Verify that no marked-for-deletion values are in any of the value sets. - tensorflow::gtl::FlatSet id_set(value_ids_to_delete_.begin(), - value_ids_to_delete_.end()); + absl::flat_hash_set id_set(value_ids_to_delete_.begin(), + value_ids_to_delete_.end()); for (const auto& pair : value_sets_) { const HloInstruction* instruction = pair.first; const InstructionValueSet& instruction_value_set = pair.second; @@ -355,23 +356,6 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) { return false; } -bool HloDataflowAnalysis::UpdateSliceValueSet(HloInstruction* slice) { - CHECK_EQ(slice->opcode(), HloOpcode::kSlice); - if (!slice->IsInPlaceSlice()) { - return false; - } - // If this slice is lowered to an in-place version, then it forwards the - // operand value to the output. 
- const InstructionValueSet& operand_set = - GetInstructionValueSet(slice->operand(0)); - InstructionValueSet& slice_set = GetInstructionValueSet(slice); - if (operand_set != slice_set) { - slice_set = operand_set; - return true; - } - return false; -} - bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { CHECK_EQ(send->opcode(), HloOpcode::kSend); bool changed = false; @@ -640,8 +624,6 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( switch (instruction->opcode()) { case HloOpcode::kBitcast: return UpdateBitcastValueSet(instruction); - case HloOpcode::kSlice: - return UpdateSliceValueSet(instruction); case HloOpcode::kDomain: return UpdateDomainValueSet(instruction); case HloOpcode::kCopy: @@ -673,7 +655,7 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( void HloDataflowAnalysis::Propagate() { std::queue worklist; - tensorflow::gtl::FlatSet workset; + absl::flat_hash_set workset; auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) { if (workset.insert(instruction).second) { worklist.push(instruction); @@ -813,11 +795,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { define_all_values(); } break; - case HloOpcode::kSlice: - if (!instruction->IsInPlaceSlice()) { - define_all_values(); - } - break; case HloOpcode::kWhile: case HloOpcode::kCall: case HloOpcode::kConditional: @@ -1071,6 +1048,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( } if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kScatter || user->opcode() == HloOpcode::kWhile) { // We eliminated other users in BufferLiveness::live_range_strictly_before, // so here we just need to check that the use is at operand index 0. 
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index e62c1c2ac81981e1f44f4c7e1479107979576e32..abac398c04fc4c418d8814a0097db4434bc1cd9c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -182,7 +182,6 @@ class HloDataflowAnalysis { // Updates the value set for a particular instruction type. Returns whether // the instruction value set changed. bool UpdateBitcastValueSet(HloInstruction* bitcast); - bool UpdateSliceValueSet(HloInstruction* slice); bool UpdateCallValueSet(HloInstruction* call); bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 510d6360a1cf94ef06d2ed919a57c7a825886834..909853106d57d181e85e3e4134b4039be2b176f5 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2283,6 +2283,44 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { + const char* hlo_text = R"( + HloModule TensorFlowScatterV1 + + update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) + } + + ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); + computation_ = module_->entry_computation(); + RunAnalysis(); + + 
HloInstruction* operand_param = computation_->parameter_instruction(0); + HloInstruction* indices_param = computation_->parameter_instruction(1); + HloInstruction* updates_param = computation_->parameter_instruction(2); + HloInstruction* scatter = computation_->root_instruction(); + + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser( + operand_param, {}, scatter, {})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser( + indices_param, {}, scatter, {})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser( + updates_param, {}, scatter, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto builder = HloComputation::Builder(TestName()); @@ -2308,7 +2346,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { auto values = builder.AddInstruction( HloInstruction::CreateParameter(1, values_shape, "values")); auto sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, + {values})); BuildModuleAndRunAnalysis(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 113fd18eae70f0a581e2ab3e44544c47fcab3361..c6d02f9f67bb599e496d20fc2acf2e627ed54438 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -17,6 +17,8 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -40,18 +42,19 @@ namespace xla { return std::move(domain_map); } -bool HloDomainMap::InSameDomain(HloInstruction* instruction1, - HloInstruction* instruction2) const { +bool HloDomainMap::InSameDomain(const HloInstruction* instruction1, + const HloInstruction* instruction2) const { int64 domain_id1 = GetDomainId(instruction1); int64 domain_id2 = GetDomainId(instruction2); return domain_id1 >= 0 && domain_id1 == domain_id2; } -int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const { +int64 HloDomainMap::GetDomainId(const HloInstruction* instruction) const { return FindOrDefault(instruction_to_domain_, instruction, -1); } -int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const { +int64 HloDomainMap::GetDomainMetadataId( + const HloInstruction* instruction) const { return FindOrDie(domain_metadata_id_, instruction); } @@ -106,8 +109,8 @@ Status HloDomainMap::PopulateDomainMetadataMap() { auto equal = [](const DomainMetadata* a, const DomainMetadata* b) { return a->Matches(*b); }; - tensorflow::gtl::FlatMap + absl::flat_hash_map domain_metadata(1024, hash, equal); for (auto& domain : instruction_domains_) { @@ -198,7 +201,8 @@ StatusOr> HloDomainMap::CreateDomain( return std::move(domain); } -bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { +bool HloDomainMap::IsDomainInstruction( + const HloInstruction* instruction) const { if (instruction->opcode() != HloOpcode::kDomain) { return false; } @@ -216,7 +220,7 @@ bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { /* static */ std::vector HloDomainMap::MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet& instruction_set, + const absl::flat_hash_set& instruction_set, const InstructionOrderMap& 
instructions_order) { std::vector instructions; instructions.reserve(instruction_set.size()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index 56b557d7cea424f63cd4891661ae446133ee5a37..bce7d1aa7cf1822ef1608674e7bf9483c628e4b5 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -19,14 +19,14 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -58,27 +58,26 @@ class HloDomainMap { } // Checks whether two instructions are within the same domain. - bool InSameDomain(HloInstruction* instruction1, - HloInstruction* instruction2) const; + bool InSameDomain(const HloInstruction* instruction1, + const HloInstruction* instruction2) const; // Checks whether instruction is a kDomain instruction of the kind we are // currently processing. - bool IsDomainInstruction(HloInstruction* instruction) const; + bool IsDomainInstruction(const HloInstruction* instruction) const; // Retrieves the domain identifier of the instruction, or -1 in case // instruction is not found within any domain. - int64 GetDomainId(HloInstruction* instruction) const; + int64 GetDomainId(const HloInstruction* instruction) const; // Returns the unique id of the domain metadata for the domain the given // instruction belongs to. The given instruction must not be a kDomain // instruction since each domain instruction is associated with 2 domains. 
- int64 GetDomainMetadataId(HloInstruction* instruction) const; + int64 GetDomainMetadataId(const HloInstruction* instruction) const; private: // Map used for representing instruction ordering, i.e. // order_map[a] < order_map[b] means a must be ordered before b. - using InstructionOrderMap = - tensorflow::gtl::FlatMap; + using InstructionOrderMap = absl::flat_hash_map; HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} @@ -111,7 +110,7 @@ class HloDomainMap { // Out of an instruction set, returns a vector of all the ones which are not // a kDomain kind. static std::vector MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet& instruction_set, + const absl::flat_hash_set& instruction_set, const InstructionOrderMap& instructions_order); // Populates domain_metadata_id_ that maps each HloInstruction to the unique @@ -120,8 +119,8 @@ class HloDomainMap { string domain_kind_; std::vector> instruction_domains_; - tensorflow::gtl::FlatMap instruction_to_domain_; - tensorflow::gtl::FlatMap domain_metadata_id_; + absl::flat_hash_map instruction_to_domain_; + absl::flat_hash_map domain_metadata_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index 302807f816e4ab626af419023e7740fd6bde795f..d3c83c15ae3be67a64f3dc4bcb0312ae9fbc33e4 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -20,11 +20,11 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -42,7 +42,7 @@ class DomainMetadata { // operand/user pathways, without crossing a kDomain instruction of a given // kind. 
The reach_set can contain kDomain instructions of other kinds, if // two domains of different kind intersect each other. - tensorflow::gtl::FlatSet reach_set; + absl::flat_hash_set reach_set; // The same instructions in reach_set, but purged from kDomain instructions // and ordered according to their computation graph post-order, i.e. @@ -55,8 +55,8 @@ class DomainMetadata { // whose dataflow enters the reach set (domain), while the exit_domains // contains the set of kDomain instructions whose dataflow exit the reach // set. - tensorflow::gtl::FlatSet enter_domains; - tensorflow::gtl::FlatSet exit_domains; + absl::flat_hash_set enter_domains; + absl::flat_hash_set exit_domains; }; virtual ~DomainMetadata() = default; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index d7c39b2778d57c1b2e9da0d87d9c2b91bb47e968..c2998883851481b3cda5a3423baa3454018117b2 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" @@ -189,6 +190,11 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) return Unimplemented( "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); }); + typed_visitors_[TOKEN] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: TOKEN."); + }); } template @@ -1228,7 +1234,7 @@ StatusOr EvaluateSortInternal(HloInstruction* sort, TF_RET_CHECK( ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) << "Sort keys and values must have the same dimensions"; - TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort"; + TF_RET_CHECK(sort->operand_count() >= 2) << "Expected key-value sort"; // We need to sort an array of keys and an array of values, where the // sorted order of the values is determined by the keys. The simplest(?) // way to do this is to go to an array-of-pairs representation, sort the @@ -1279,7 +1285,9 @@ StatusOr EvaluateSortInternal(HloInstruction* sort, return SafeLess(a.first, b.first); }); std::vector result_keys; - std::vector result_values; + // We use a InlinedVector here because we need to convert it to an + // absl::Span later, and this would not work with std::vector. 
+ absl::InlinedVector result_values; for (const auto& key_value : key_value_vector) { result_keys.push_back(key_value.first); result_values.push_back(key_value.second); @@ -1315,7 +1323,10 @@ template StatusOr EvaluateSortCurried(HloInstruction* sort, const Literal& keys_literal, const Literal& values_literal) { - switch (sort->operand(1)->shape().element_type()) { + switch (values_literal.shape().element_type()) { + case PRED: + return EvaluateSortInternal(sort, keys_literal, + values_literal); case F32: return EvaluateSortInternal(sort, keys_literal, values_literal); @@ -1355,14 +1366,24 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { if (!ShapeUtil::IsTuple(sort->shape())) { return DefaultAction(sort); } else { - auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)), - GetEvaluatedLiteralFor(sort->operand(1))); - if (result.ok()) { - evaluated_[sort] = std::move(result.ValueOrDie()); - return Status::OK(); - } else { - return result.status(); + // This is a really stupid work-around for the fact it's hard to support a + // multi-value sort directly, due to the fact we need to template the + // evaluation function on all of the value types. 
+ std::vector sort_results_backing; + for (int64 i = 0; i < sort->operand_count(); ++i) { + auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)), + GetEvaluatedLiteralFor(sort->operand(i))); + if (!result.ok()) { + return result.status(); + } + sort_results_backing.push_back( + std::move(result.ValueOrDie().DecomposeTuple()[1])); } + std::vector sort_results; + absl::c_transform(sort_results_backing, std::back_inserter(sort_results), + [](const Literal& literal) { return &literal; }); + evaluated_[sort] = LiteralUtil::MakeTuple(sort_results); + return Status::OK(); } } @@ -1378,7 +1399,7 @@ Status HloEvaluator::HandleReduce(HloInstruction* reduce) { "unsupported"); } } - return reduce->Visit(typed_visitors_.at(first_element_type).get()); + return reduce->Visit(typed_visitors_[first_element_type].get()); } } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 6c2662ebaeff5ff3ae21b19fac430c3490e22d36..07f8d0aad4af0b07303b4e485b3630cc75bcb519 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/node_hash_map.h" #include "absl/memory/memory.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -134,7 +134,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Wraps around instruction handling to infer types before dispatching to // the corresponding typed Visitor. 
Status DefaultAction(HloInstruction* hlo) override { - return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get()); + return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get()); } Status Preprocess(HloInstruction* hlo) override; @@ -210,8 +210,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // post-orderring. // Must be cleared for each evaluation. // Storing Literal in place require the container to have pointer stability so - // we cannot use FlatMap any more. - std::unordered_map evaluated_; + // we cannot use flat_hash_map any more. + absl::node_hash_map evaluated_; private: template @@ -241,12 +241,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { } // Map from a primitive type to its associated (templated) DfsHloVisitor. - // Note: the hash function here is only needed because current gcc std::hash - // does not specialize for enum types. This should however be fixed in the - // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5 - tensorflow::gtl::FlatMap, - std::hash> - typed_visitors_; + std::unique_ptr typed_visitors_[PrimitiveType_ARRAYSIZE]; // Caches pointers to input literals, assuming they are in post-order. 
// Literals are not owned by this class, and they must outlive the lifetime of diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index cee11a8a2166f96ae801095b6364921ed05d0000..608a42bb60702aa075daca39535ca1672dcc5467 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -1463,6 +1463,58 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { + HloComputation::Builder b(TestName()); + + // arg: + // f32[3,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // { 9, 10, 11 }, + // } + auto arg_array = absl::make_unique>(3, 3); + arg_array->FillUnique(1.0f); + auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); + + HloInstruction* arg_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); + + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); + + HloComputation::Builder max_computation("max"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + max_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); + auto max_func = module().AddEmbeddedComputation(max_computation.Build()); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(2); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + Shape shape = ShapeUtil::MakeShape(F32, {1, 1}); + b.AddInstruction(HloInstruction::CreateReduceWindow( + shape, 
arg_instruction, init_value, window, max_func)); + + module().AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + auto expected = LiteralUtil::CreateR2({{11}}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + TEST_P(HloEvaluatorTest, ReduceWindowAdd) { HloComputation::Builder b(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index b2d12c94b848e4fd8ae473fdc0e4a9f5fecf6286..84fbbd3e0c3ddb704b8db601897f3b199dc99626 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1072,66 +1072,66 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Convolve input feature with kernel. do { + // Find corresponding spatial dimension index for input (lhs). + int64 lhs_linear_spatial_index = 0; + int64 rhs_linear_spatial_index = 0; + for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { + // Spatial dimension number for input (lhs) and output. + const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); + const int64 output_spatial_dim = dnums.output_spatial_dimensions(ki); + + // Calculate lhs (input) index without taking base dilation into + // account. + const auto& window_dim = window.dimensions(ki); + const int64 undilated_index = + out_index[output_spatial_dim] * window_dim.stride() - + window_dim.padding_low() + + rhs_spatial_index[ki] * window_dim.window_dilation(); + // Skip if the lhs (input) index is to be dilated. As an + // optimization, skip this mod if there's no dilation. + if (window_dim.base_dilation() > 1 && + undilated_index % window_dim.base_dilation() != 0) { + goto cnt; + } + + // Calculate the actual lhs (input) index after dilation. As an + // optimization, skip this integer divide if there's no dilation. 
+ int64 lhs_spatial_index; + if (window_dim.base_dilation() > 1) { + lhs_spatial_index = undilated_index / window_dim.base_dilation(); + } else { + lhs_spatial_index = undilated_index; + } + + // Skip if input index is not in bounds. + if (!(lhs_spatial_index >= 0 && + lhs_spatial_index < lhs_shape.dimensions(input_spatial_dim))) { + goto cnt; + } + + lhs_linear_spatial_index += + lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; + rhs_linear_spatial_index += + (window_dim.window_reversal() + ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) + : rhs_spatial_index[ki]) * + rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; + } + for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) { const int64 iz = feature_group_index * input_feature_group_size + rhs_iz; - int64 lhs_linear_index = 0; + int64 lhs_linear_index = lhs_linear_spatial_index; lhs_linear_index += out_index[output_batch_dim] * lhs_dim_multipliers[input_batch_dim]; lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; - int64 rhs_linear_index = 0; + int64 rhs_linear_index = rhs_linear_spatial_index; rhs_linear_index += out_index[output_z_dim] * rhs_dim_multipliers[kernel_output_z_dim]; rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim]; - // Find corresponding spatial dimension index for input (lhs). - for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { - // Spatial dimension number for input (lhs) and output. - const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); - const int64 output_spatial_dim = - dnums.output_spatial_dimensions(ki); - - // Calculate lhs (input) index without taking base dilation into - // account. - const auto& window_dim = window.dimensions(ki); - const int64 undilated_index = - out_index[output_spatial_dim] * window_dim.stride() - - window_dim.padding_low() + - rhs_spatial_index[ki] * window_dim.window_dilation(); - // Skip if the lhs (input) index is to be dilated. 
As an - // optimization, skip this mod if there's no dilation. - if (window_dim.base_dilation() > 1 && - undilated_index % window_dim.base_dilation() != 0) { - goto cnt; - } - - // Calculate the actual lhs (input) index after dilation. As an - // optimization, skip this integer divide if there's no dilation. - int64 lhs_spatial_index; - if (window_dim.base_dilation() > 1) { - lhs_spatial_index = undilated_index / window_dim.base_dilation(); - } else { - lhs_spatial_index = undilated_index; - } - lhs_linear_index += - lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; - - // Skip if input index is not in bounds. - if (!(lhs_spatial_index >= 0 && - lhs_spatial_index < - lhs_shape.dimensions(input_spatial_dim))) { - goto cnt; - } - - rhs_linear_index += - (window_dim.window_reversal() - ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) - : rhs_spatial_index[ki]) * - rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; - } - result_val += static_cast(lhs_literal_data[lhs_linear_index]) * static_cast(rhs_literal_data[rhs_linear_index]); @@ -2613,8 +2613,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector base_index(rank); bool out_of_bound = false; for (int64 i = 0; i < rank; ++i) { - base_index[i] = window_count_index[i] * window.dimensions(i).stride() + - window_index[i] - window.dimensions(i).padding_low(); + base_index[i] = + window_count_index[i] * window.dimensions(i).stride() + + window_index[i] * window.dimensions(i).window_dilation() - + window.dimensions(i).padding_low(); + // We are not in the base area if the dilation placed us out of bounds. + if (base_index[i] % window.dimensions(i).base_dilation() != 0) { + out_of_bound = true; + break; + } + // Apply the dilation to the base area. 
+ base_index[i] /= window.dimensions(i).base_dilation(); if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { out_of_bound = true; break; diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index de3d7a167752f0de790585e50874dd6d2904bd37..ce4cad42355ec5881f2ae14f4dd52a0588d51cf7 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -90,8 +90,9 @@ std::unique_ptr CreateHloProfilePrinterData( HloInstructionInfo* instruction_info = computation_info->add_instruction_infos(); instruction_info->set_long_name(hlo->ToString()); - instruction_info->set_short_name( - hlo->ToString(HloPrintOptions().set_compact_operands(true))); + instruction_info->set_short_name(hlo->ToString( + HloPrintOptions().set_compact_operands(true).set_print_operand_names( + false))); instruction_info->set_category(hlo->ToCategory()); instruction_info->set_flop_count(cost_analysis.flop_count(*hlo)); instruction_info->set_transcendental_count( diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc new file mode 100644 index 0000000000000000000000000000000000000000..8128fad07ca0b9c3883ed93c6e1c8e977e990cb4 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc @@ -0,0 +1,182 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { +Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, + int64 param_number, + const ShapeIndex& param_index) { + TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)) + << absl::StrCat("Tring to set up alias at ", output_index.ToString(), + " which is an invalid index for shape ", + ShapeUtil::HumanString(alias_.shape())); + // Output can't be aliased with multiple parameters. + TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat( + "Trying to set up output alias for param %lld at %s but failed: output " + "index %s is already aliased with param %lld at %s", + param_number, param_index.ToString(), output_index.ToString(), + alias_.element(output_index)->first, + alias_.element(output_index)->second.ToString()); + (*alias_.mutable_element(output_index)) = + std::make_pair(param_number, param_index); + return Status::OK(); +} + +HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const { + HloInputOutputAliasProto result; + alias_.ForEachElement( + [&](const ShapeIndex& index, + const absl::optional>& data) { + if (data) { + HloInputOutputAliasProto::AliasEntryProto entry; + for (int64 i : index) { + entry.add_output_shape_index(i); + } + entry.set_parameter_number(data->first); + for (int64 i : data->second) { + entry.add_parameter_shape_index(i); + } + result.add_entries()->Swap(&entry); + } + }); + return result; +} + +StatusOr HloInputOutputAliasConfig::CreateFromProto( + const Shape& output_shape, const HloInputOutputAliasProto& proto) { + HloInputOutputAliasConfig result(output_shape); + for (const HloInputOutputAliasProto::AliasEntryProto& entry : + proto.entries()) { + 
ShapeIndex output_index(entry.output_shape_index().begin(), + entry.output_shape_index().end()); + + int64 param_number = entry.parameter_number(); + ShapeIndex param_index(entry.parameter_shape_index().begin(), + entry.parameter_shape_index().end()); + TF_RETURN_IF_ERROR( + result.SetUpAlias(output_index, param_number, param_index)); + } + + return result; +} + +string HloInputOutputAliasConfig::ToString() const { + std::vector pieces; + pieces.push_back("HloInputOutputAliasConfig"); + + ForEachAlias([&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + pieces.push_back(absl::StrFormat( + " OutputIndex %s is aliased with parameter %lld at %s:", + output_index.ToString(), param_number, param_index.ToString())); + }); + + return absl::StrJoin(pieces, "\n"); +} + +bool HloInputOutputAliasConfig::ParameterHasAlias( + int64 param_number, const ShapeIndex& param_index) const { + bool output = false; + alias_.ForEachElement( + [&](const xla::ShapeIndex&, + absl::optional> alias) { + if (alias && alias->first == param_number && + alias->second == param_index) { + output = true; + } + }); + return output; +} + +absl::optional HloInputOutputAliasConfig::GetAliasedOutput( + int64 param_number, const ShapeIndex& param_index) const { + absl::optional output; + alias_.ForEachElement( + [&](const xla::ShapeIndex& output_index, + absl::optional> alias) { + if (alias && alias->first == param_number && + alias->second == param_index) { + output = output_index; + } + }); + return output; +} + +absl::optional> +HloInputOutputAliasConfig::GetAliasedParameter( + const ShapeIndex& output_index) const { + CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)); + return alias_.element(output_index); +} + +void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const { + alias_.ForEachElement( + [&](const ShapeIndex& output_index, + absl::optional> aliased) { + if (aliased) { + fn(output_index, aliased->first, aliased->second); + } + }); +} + 
+Status HloInputOutputAliasConfig::ForEachAliasWithStatus( + AliasFnWithStatus fn) const { + return alias_.ForEachElementWithStatus( + [&](const ShapeIndex& output_index, + absl::optional> aliased) { + if (aliased) { + TF_RETURN_IF_ERROR(fn(output_index, aliased->first, aliased->second)); + } + return Status::OK(); + }); +} + +Status HloInputOutputAliasConfig::Verify(const HloModule& module) const { + std::vector> param_has_seen; + const HloComputation* entry = module.entry_computation(); + for (int64 i = 0; i < entry->num_parameters(); ++i) { + HloInstruction* param = entry->parameter_instruction(i); + param_has_seen.emplace_back(param->shape()); + } + return ForEachAliasWithStatus([&](const ShapeIndex& output_index, + int64 param_number, + const ShapeIndex& param_index) -> Status { + const HloInstruction* root = entry->root_instruction(); + + const Shape& param_shape = + entry->parameter_instruction(param_number)->shape(); + const Shape& output_shape = root->shape(); + TF_RET_CHECK(entry->num_parameters() > param_number); + TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, param_index)); + TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index)); + + // Check each param_number and param_index pair only show up once. No + // input can be aliased with output buffers. 
+ TF_RET_CHECK(param_has_seen[param_number].element(param_index) == false); + + *(param_has_seen[param_number].mutable_element(param_index)) = true; + + return Status::OK(); + }); +} + +std::ostream& operator<<(std::ostream& out, + const HloInputOutputAliasConfig& config) { + out << config.ToString(); + return out; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h new file mode 100644 index 0000000000000000000000000000000000000000..0fae75842ba28da5dcb59e5952cd60c1d1c5ea68 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ + +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +class HloModule; + +// This class specifies the alias map from output index to parameter number and +// parameter index in the entry computation. 
+class HloInputOutputAliasConfig { + public: + HloInputOutputAliasConfig() = default; + + explicit HloInputOutputAliasConfig(Shape shape) : alias_(shape) {} + + virtual ~HloInputOutputAliasConfig() = default; + + // Sets up alias config from `output_index` to `param_index` at + // `param_number`. + Status SetUpAlias(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index); + + // Returns true if the given parameter is aliased with one of the output + // buffers. + bool ParameterHasAlias(int64 param_number, + const ShapeIndex& param_index) const; + + // (De)Serializes an HloInputOutoutAliasConfig to/from an + // HloInputOutoutAliasProto. + HloInputOutputAliasProto ToProto() const; + + static StatusOr CreateFromProto( + const Shape& output_shape, const HloInputOutputAliasProto& proto); + + // Returns the output index that the given parameter and parameter index is + // aliased with. A nullopt is returned if there is no output that is aliased + // with the parameter number and index. + absl::optional GetAliasedOutput( + int64 param_number, const ShapeIndex& param_index) const; + + // Returns the number of parameter and index of the parameter buffer that the + // given output buffer index is aliased with. A nullopt is returned if there + // is no parameter is aliased with the specific output. + absl::optional> GetAliasedParameter( + const ShapeIndex& output_index) const; + + using AliasFn = + std::function; + + // Iterates through each aliased output and input. + void ForEachAlias(AliasFn fn) const; + + using AliasFnWithStatus = + std::function; + + // Verifies that the given config is valid for the given module. + // Specifically, the config's input and output should be in-bound and size of + // the aliased buffers should match. 
+ Status Verify(const HloModule& module) const; + + Status ForEachAliasWithStatus(AliasFnWithStatus fn) const; + + string ToString() const; + + private: + // A ShapeTree which indicates the list of buffers that's expected to be + // aliased. The key on this shape tree represents the output index. The value + // is a pair of parameter number and index into the buffer. If the value is + // nullopt, it means there is no parameter aliasing for this output. + ShapeTree>> alias_; +}; + +std::ostream& operator<<(std::ostream& out, + const HloInputOutputAliasConfig& config); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3b61ff04e6d7eeaa5876775fa18a85af82164b3d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { +class HloInputOutputAliasConfigTest : public HloTestBase { + protected: + void expect_aliased(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index, + const HloInputOutputAliasConfig& config) { + absl::optional aliased_output = + config.GetAliasedOutput(param_number, param_index); + + EXPECT_TRUE(aliased_output); + EXPECT_EQ(aliased_output.value(), output_index); + + absl::optional> aliased_param = + config.GetAliasedParameter(output_index); + + EXPECT_TRUE(aliased_param); + EXPECT_EQ(aliased_param.value(), std::make_pair(param_number, param_index)); + } + + void expect_not_aliased(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index, + const HloInputOutputAliasConfig& config) { + absl::optional aliased_output = + config.GetAliasedOutput(param_number, param_index); + + EXPECT_FALSE(aliased_output && aliased_output == output_index); + + absl::optional> aliased_param = + config.GetAliasedParameter(output_index); + + EXPECT_FALSE(aliased_param && aliased_param->first == param_number && + 
aliased_param->second == param_index); + } +}; + +TEST_F(HloInputOutputAliasConfigTest, SimpleAliasing) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT root = (f32[], f32[]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{})); + + expect_aliased(/*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{}, config); + + expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1, + /*param_index=*/{}, config); + + expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, config); +} + +TEST_F(HloInputOutputAliasConfigTest, SimpleAliasingWithTupleInput) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + param = (f32[], f32[]) parameter(0) + gte1 = f32[] get-tuple-element(%param), index=0 + gte2 = f32[] get-tuple-element(%param), index=1 + ROOT root = (f32[], f32[]) tuple(%gte1, %gte2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{0})); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{1})); + + expect_aliased(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{0}, config); + + expect_aliased(/*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{1}, config); + + expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1, + /*param_index=*/{}, config); + + expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, config); +} + +TEST_F(HloInputOutputAliasConfigTest, InputDoNotAliasTwice) 
{ + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT root = (f32[], f32[]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{})); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{})); + + ASSERT_IS_NOT_OK(config.Verify(*module)); +} + +TEST_F(HloInputOutputAliasConfigTest, OutputDoNotAliasTwice) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT root = (f32[], f32[]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{})); + + ASSERT_IS_NOT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{})); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 23787dbc8abb300d063e8dd552b2299ff5b36435..f6ed86b41650fd331201814559386ff644092c23 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/ascii.h" @@ -37,14 +39,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" @@ -59,8 +60,8 @@ using absl::StrJoin; /* static */ StatusOr> HloInstruction::CreateFromProto( const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map) { + const absl::flat_hash_map& instruction_map, + const absl::flat_hash_map& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); @@ -80,6 +81,20 @@ StatusOr> HloInstruction::CreateFromProto( const auto computations = [&computation_map, &proto](int index) { return computation_map.at(proto.called_computation_ids(index)); }; + + TF_RET_CHECK(std::all_of( + proto.operand_ids().begin(), proto.operand_ids().end(), + [&instruction_map](int64 id) { return instruction_map.contains(id); })) + << proto.name() << " instruction contains invalid operand id(s)"; + + TF_RET_CHECK(std::all_of( + proto.called_computation_ids().begin(), + proto.called_computation_ids().end(), + [&computation_map](int64 id) { return computation_map.contains(id); })) + << proto.name() << " instruction references invalid computation id(s)"; + + 
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + switch (opcode) { // Ops migrated to subclasses. case HloOpcode::kBatchNormTraining: @@ -180,17 +195,16 @@ StatusOr> HloInstruction::CreateFromProto( } break; case HloOpcode::kSort: { - TF_RET_CHECK(proto.operand_ids_size() == 1 || - proto.operand_ids_size() == 2) - << "Sort instruction should have 1 or 2 operands but has " + TF_RET_CHECK(proto.operand_ids_size() >= 1) + << "Sort instruction should have at least 1 operand but has " << proto.operand_ids_size(); TF_RET_CHECK(proto.dimensions().size() == 1) << "Sort instruction should have 1 dimension"; - HloInstruction* keys = operands(0); - HloInstruction* values = - proto.operand_ids_size() == 2 ? operands(1) : nullptr; - instruction = - CreateSort(proto.shape(), proto.dimensions(0), keys, values); + auto sort_operands = all_operands(); + HloInstruction* keys = sort_operands[0]; + instruction = CreateSort( + proto.shape(), proto.dimensions(0), keys, + absl::Span(sort_operands).subspan(1)); break; } case HloOpcode::kTranspose: @@ -266,7 +280,8 @@ StatusOr> HloInstruction::CreateFromProto( << "Expect 1 called computation for fusion instruction but sees " << proto.called_computation_ids_size(); const int64 fusion_id = proto.called_computation_ids(0); - auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); + auto* fused_computation = + tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id); TF_RET_CHECK(fused_computation != nullptr) << "No fusion computation with id " << fusion_id; instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(), @@ -289,6 +304,9 @@ StatusOr> HloInstruction::CreateFromProto( proto.tuple_index()); break; case HloOpcode::kReducePrecision: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "ReducePrecision instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateReducePrecision(proto.shape(), operands(0), proto.exponent_bits(), 
proto.mantissa_bits()); @@ -296,12 +314,18 @@ StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kInfeed: { const Shape& data_shape = ShapeUtil::GetTupleElementShape(proto.shape(), 0); - TF_RET_CHECK(proto.operand_ids_size() == 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Infeed instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateInfeed(data_shape, operands(0), proto.infeed_config()); } break; case HloOpcode::kOutfeed: - TF_RET_CHECK(proto.operand_ids_size() == 2); + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Outfeed instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape())); instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), operands(1), proto.outfeed_config()); break; @@ -331,6 +355,9 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kCollectivePermute: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "CollectivePermute instruction should have 1 operand but sees " + << proto.operand_ids_size(); std::vector> source_target_pairs( proto.source_target_pairs_size()); for (int i = 0; i < source_target_pairs.size(); i++) { @@ -378,9 +405,22 @@ StatusOr> HloInstruction::CreateFromProto( operands(1), operands(2), computations(1)); break; case HloOpcode::kCustomCall: - instruction = CreateCustomCall(proto.shape(), all_operands(), - proto.custom_call_target(), - proto.custom_call_opaque()); + if (proto.constrain_layout()) { + // A proto RepeatedPtrField cannot be converted to a Span (it is a + // vector of pointers essentially) so create a vector of shapes to pass + // in. 
+ std::vector operand_shapes; + for (const Shape& shape : proto.operand_shapes_with_layout()) { + operand_shapes.push_back(shape); + } + instruction = CreateCustomCall( + proto.shape(), all_operands(), proto.custom_call_target(), + operand_shapes, proto.custom_call_opaque()); + } else { + instruction = CreateCustomCall(proto.shape(), all_operands(), + proto.custom_call_target(), + proto.custom_call_opaque()); + } if (proto.has_window()) { static_cast(instruction.get()) ->set_window(proto.window()); @@ -447,8 +487,8 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kIota: - TF_RET_CHECK(proto.dimensions_size() <= 1) - << "Iota instruction should have at most 1 dimension but sees " + TF_RET_CHECK(proto.dimensions_size() == 1) + << "Iota instruction should have 1 dimension but sees " << proto.dimensions_size(); instruction = CreateIota(proto.shape(), proto.dimensions(0)); break; @@ -466,31 +506,37 @@ StatusOr> HloInstruction::CreateFromProto( proto.dot_dimension_numbers(), precision_config); break; } - case HloOpcode::kDomain: + case HloOpcode::kDomain: { TF_RET_CHECK(proto.operand_ids_size() == 1) << "Domain instruction should have 1 operands but sees " << proto.operand_ids_size(); + std::shared_ptr entry_hlo_sharding; + std::shared_ptr exit_hlo_sharding; + if (proto.has_domain_entry_sharding()) { + TF_ASSIGN_OR_RETURN( + HloSharding sharding, + HloSharding::FromProto(proto.domain_entry_sharding())); + entry_hlo_sharding = std::make_shared(sharding); + } + if (proto.has_domain_exit_sharding()) { + TF_ASSIGN_OR_RETURN( + HloSharding sharding, + HloSharding::FromProto(proto.domain_exit_sharding())); + exit_hlo_sharding = std::make_shared(sharding); + } instruction = absl::make_unique( - proto.shape(), operands(0), /*operand_side_metadata=*/nullptr, - /*user_side_metadata=*/nullptr); + proto.shape(), operands(0), + absl::make_unique(entry_hlo_sharding), + absl::make_unique(exit_hlo_sharding)); break; + } default: { instruction = absl::WrapUnique(new 
HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) - << "No instruction with id " << operand_id; instruction->AppendOperand(instruction_map.at(operand_id)); } - for (const int64 predecessor_id : proto.control_predecessor_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) - << "No instruction with id " << predecessor_id; - TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) - ->AddControlDependencyTo(instruction.get())); - } if (instruction->opcode() != HloOpcode::kFusion) { for (const int64 computation_id : proto.called_computation_ids()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_id)) - << "No computation with id " << computation_id; instruction->called_computations_.push_back( computation_map.at(computation_id)); } @@ -502,6 +548,13 @@ StatusOr> HloInstruction::CreateFromProto( } } + for (const int64 predecessor_id : proto.control_predecessor_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) + << "No instruction with id " << predecessor_id; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) + ->AddControlDependencyTo(instruction.get())); + } + TF_RET_CHECK(!proto.name().empty()); instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); @@ -1027,7 +1080,7 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateSort( const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values) { + absl::Span values) { return absl::make_unique(shape, dimension, keys, values); } @@ -1114,6 +1167,15 @@ bool HloInstruction::HasSideEffect() const { shape, operands, custom_call_target, opaque); } +/* static */ std::unique_ptr HloInstruction::CreateCustomCall( + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, + absl::Span operand_shapes_with_layout, + absl::string_view opaque) { + return 
absl::make_unique( + shape, operands, custom_call_target, opaque, operand_shapes_with_layout); +} + /* static */ std::unique_ptr HloInstruction::CreateTuple( absl::Span elements) { std::vector element_shapes; @@ -1432,7 +1494,7 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { HloInstruction::InstructionVector HloInstruction::unique_operands() const { InstructionVector unique; - tensorflow::gtl::FlatSet seen; + absl::flat_hash_set seen; for (HloInstruction* operand : operands()) { if (seen.insert(operand).second) { unique.push_back(operand); @@ -2006,7 +2068,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( options.is_in_nested_computation()) { str.push_back(PrintName( canonical_name_map->LookupOrInsert(operand->name()), options)); - } else if (!options.compact_operands()) { + } else if (options.print_operand_names()) { str.push_back(PrintName(operand->name(), options)); } StrAppend(out, StrJoin(str, " ")); @@ -2618,7 +2680,6 @@ Status HloInstruction::AcceptOrdered( } const Shape& HloInstruction::shape() const { - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); return shape_; } @@ -2661,14 +2722,14 @@ class HloInstruction::FusionReusesParamElements { // the value of this parameter, which would save stack space but not allow us // to finish early if we find a reuse. 
static UseKind Compute(int64 i, const HloInstruction& hlo) { - tensorflow::gtl::FlatMap memoization_cache; + absl::flat_hash_map memoization_cache; return ComputeInternal(i, hlo, &memoization_cache); } private: static UseKind ComputeInternal( int64 i, const HloInstruction& hlo, - tensorflow::gtl::FlatMap* cache) { + absl::flat_hash_map* cache) { if (auto hlo_param = DynCast(&hlo)) { if (hlo_param->parameter_number() == i) { return UseKind::kUse; @@ -3048,10 +3109,6 @@ const std::vector& HloInstruction::slice_strides() const { return Cast(this)->slice_strides(); } -bool HloInstruction::IsInPlaceSlice() const { - return Cast(this)->IsInPlaceSlice(); -} - const Literal& HloInstruction::literal() const { return Cast(this)->literal(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 009bd3bab3684056247beb361a3b8662e6901f99..15a4da8dbe0053aad314989a6718ebd61532ab8b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -28,10 +28,10 @@ limitations under the License. #include #include #include -#include -#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -50,7 +50,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -80,6 +79,7 @@ class HloPrintOptions { print_backend_config_(true), compact_operands_(false), print_operand_shape_(true), + print_operand_names_(true), print_program_shape_(true), print_percent_(true), print_control_dependencies_(true), @@ -107,6 +107,7 @@ class HloPrintOptions { .set_print_metadata(false) .set_print_backend_config(false) .set_compact_operands(true) + .set_print_operand_names(false) .set_print_operand_shape(true) .set_print_program_shape(false) .set_print_percent(false) @@ -144,6 +145,12 @@ class HloPrintOptions { return *this; } + // If true, the operand names will be printed. + HloPrintOptions& set_print_operand_names(bool value) { + print_operand_names_ = value; + return *this; + } + // If true, program shape of hlo computations will be printed. HloPrintOptions& set_print_program_shape(bool value) { print_program_shape_ = value; @@ -162,8 +169,8 @@ class HloPrintOptions { return *this; } - // If true, only a part of operands will be printed out, and their names will - // be omitted (note that in this case the text will not be parsable). + // If true, only a part of operands will be printed out (note that in this + // case the text will not be parsable). 
HloPrintOptions& set_compact_operands(bool value) { compact_operands_ = value; return *this; @@ -197,6 +204,7 @@ class HloPrintOptions { bool print_backend_config() const { return print_backend_config_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } + bool print_operand_names() const { return print_operand_names_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } bool print_control_dependencies() const { @@ -215,6 +223,7 @@ class HloPrintOptions { bool print_backend_config_; bool compact_operands_; bool print_operand_shape_; + bool print_operand_names_; bool print_program_shape_; bool print_percent_; bool print_control_dependencies_; @@ -247,7 +256,7 @@ class CanonicalNameMap { private: int64 index; - tensorflow::gtl::FlatMap canonical_name_map; + absl::flat_hash_map canonical_name_map; }; // HLO instructions are the atomic unit of the high-level compiler's IR. @@ -350,8 +359,8 @@ class HloInstruction { // calls. static StatusOr> CreateFromProto( const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map); + const absl::flat_hash_map& instruction_map, + const absl::flat_hash_map& computation_map); // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, @@ -454,7 +463,7 @@ class HloInstruction { // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will // not be applied cross modules. // - // TODO(b/79737069): Rename this to AllReduce. + // TODO(b/117564385): Rename this to AllReduce. 
static std::unique_ptr CreateCrossReplicaSum( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, @@ -660,10 +669,10 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, absl::Span dimensions); - // Creates a sort op, with a keys operand, and an optional values operand. + // Creates a sort op, with a keys operand, and optional values operands. static std::unique_ptr CreateSort( const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values = nullptr); + absl::Span values = {}); // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. For @@ -724,6 +733,16 @@ class HloInstruction { const Shape& shape, absl::Span operands, absl::string_view custom_call_target, absl::string_view opaque = ""); + // Overload which constrains the layouts of the operand and result. 'shape' + // and 'operand_shapes_with_layout' must have layouts. + // 'operand_shapes_with_layout' must have a compatible element for each + // operand. + static std::unique_ptr CreateCustomCall( + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, + absl::Span operand_shapes_with_layout, + absl::string_view opaque = ""); + // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. static std::unique_ptr CreateTuple( @@ -1320,9 +1339,6 @@ class HloInstruction { int64 slice_strides(int64 dimension) const; const std::vector& slice_strides() const; - // Delegates to HloSliceInstruction::IsInPlaceSlice. - bool IsInPlaceSlice() const; - // Returns the literal associated with this instruction. const Literal& literal() const; @@ -1628,7 +1644,7 @@ class HloInstruction { // members. The set enables fast membership testing and the vector enables // fast, stable iteration. 
std::vector users_; - std::unordered_set user_set_; + absl::flat_hash_set user_set_; // The set of control successors of this instruction. std::vector control_successors_; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index c1b7c3832b44b5d65b715dffa5211a5c92e17953..d93351fe0435b5f29035dc4ea0621a8c576bfd5a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -135,7 +135,8 @@ TEST_F(HloInstructionTest, BasicProperties) { auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo"); EXPECT_EQ(HloOpcode::kParameter, parameter->opcode()); - EXPECT_TRUE(ShapeUtil::IsScalarF32(parameter->shape())); + EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32)); + EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32)); EXPECT_EQ(0, parameter->operand_count()); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index cd71bc332319b825e94459fad23e61a2246dd3f7..88495e80000c4f87a778c4fad747f6bdf09b7a14 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" @@ -27,8 +28,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/window_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { @@ -213,6 +214,7 @@ HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, HloInstructionProto HloSendRecvInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_channel_id(channel_id_); + proto.set_is_host_transfer(is_host_transfer_); return proto; } @@ -598,11 +600,11 @@ std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values) + absl::Span values) : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) { AppendOperand(keys); - if (values) { - AppendOperand(values); + for (auto* value : values) { + AppendOperand(value); } } @@ -631,9 +633,8 @@ std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { HloInstruction* keys = new_operands[0]; - HloInstruction* values = new_operands.size() == 2 ? 
new_operands[1] : nullptr; return absl::make_unique(shape, dimensions(0), keys, - values); + new_operands.subspan(1)); } HloTransposeInstruction::HloTransposeInstruction( @@ -641,14 +642,6 @@ HloTransposeInstruction::HloTransposeInstruction( absl::Span dimensions) : HloInstruction(HloOpcode::kTranspose, shape), dimensions_(dimensions.begin(), dimensions.end()) { - CHECK_EQ(shape.dimensions().size(), dimensions.size()); - CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); - CHECK(std::equal(operand->shape().dimensions().begin(), - operand->shape().dimensions().end(), - Permute(dimensions, shape.dimensions()).begin())) - << "shape: " << ShapeUtil::HumanString(shape) - << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << StrJoin(dimensions, ", ") << "}"; AppendOperand(operand); } @@ -1042,7 +1035,8 @@ HloInstruction* HloFusionInstruction::AddFusionOperand( const int64 param_no = operand_count(); // Name the parameter after the instruction it represents in the outer // (non-fusion) computation. - string param_name = StrCat(new_operand->name(), ".param_", param_no); + // string param_name = StrCat(new_operand->name(), ".param_", param_no); + string param_name = StrCat("param_", param_no); HloInstruction* fused_parameter = fused_instructions_computation()->AddParameter( HloInstruction::CreateParameter(param_no, new_operand->shape(), @@ -1098,7 +1092,7 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput( // Note that we add the unfused instructions to this->parent_ computation. // This is necessary because the unique_id needs for an instruction and // it's only added when inserting to the computation. 
- tensorflow::gtl::FlatMap old_to_new; + absl::flat_hash_map old_to_new; std::vector unfused_instructions; auto computation_to_merge = instruction_to_merge->fused_instructions_computation(); @@ -1391,7 +1385,7 @@ std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( } Status HloFusionInstruction::DeduplicateFusionOperands() { - tensorflow::gtl::FlatMap operand_indices; + absl::flat_hash_map operand_indices; std::vector operands_to_remove; for (int i = 0; i < operand_count(); ++i) { auto emplace_result = operand_indices.emplace(operand(i), i); @@ -1488,7 +1482,6 @@ HloParameterInstruction::CloneWithNewOperandsImpl( HloGetTupleElementInstruction::HloGetTupleElementInstruction( const Shape& shape, HloInstruction* operand, int64 index) : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) { - CHECK(ShapeUtil::IsTuple(operand->shape())); AppendOperand(operand); } @@ -1610,9 +1603,6 @@ HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape, : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), outfeed_shape_(outfeed_shape), outfeed_config_(outfeed_config) { - CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) - << "Outfeed shape " << outfeed_shape - << " must be compatible with operand shape " << operand->shape(); AppendOperand(operand); AppendOperand(token_operand); } @@ -1834,7 +1824,24 @@ HloCustomCallInstruction::HloCustomCallInstruction( : HloInstruction(HloOpcode::kCustomCall, shape), custom_call_target_(custom_call_target.begin(), custom_call_target.end()), opaque_(opaque.begin(), opaque.end()), - feature_group_count_(1) { + feature_group_count_(1), + layout_constrained_(false) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloCustomCallInstruction::HloCustomCallInstruction( + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, absl::string_view opaque, + absl::Span operand_shapes_with_layout) + : HloInstruction(HloOpcode::kCustomCall, 
shape), + custom_call_target_(custom_call_target.begin(), custom_call_target.end()), + opaque_(opaque.begin(), opaque.end()), + feature_group_count_(1), + layout_constrained_(true), + operand_shapes_with_layout_(operand_shapes_with_layout.begin(), + operand_shapes_with_layout.end()) { for (auto operand : operands) { AppendOperand(operand); } @@ -1852,6 +1859,12 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { proto.set_custom_call_target(custom_call_target_); proto.set_custom_call_opaque(opaque_); proto.set_feature_group_count(feature_group_count_); + if (layout_constrained()) { + proto.set_constrain_layout(true); + for (const Shape& shape : operand_shapes_with_layout_) { + *proto.add_operand_shapes_with_layout() = shape; + } + } return proto; } @@ -1879,6 +1892,14 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( if (!opaque_.empty()) { extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\"")); } + if (layout_constrained()) { + std::vector shape_strings; + for (const Shape& shape : operand_shapes_with_layout_) { + shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape)); + } + extra.push_back(StrCat("operand_layout_constraints={", + StrJoin(shape_strings, ", "), "}")); + } return extra; } @@ -2309,4 +2330,23 @@ std::unique_ptr HloDomainInstruction::CloneWithNewOperandsImpl( shape, new_operands[0], operand_side_metadata_->Clone(), user_side_metadata_->Clone()); } + +HloInstructionProto HloDomainInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + auto operand_side_sharding = + dynamic_cast(operand_side_metadata_.get()); + if (operand_side_sharding && operand_side_sharding->sharding() != nullptr) { + *proto.mutable_domain_entry_sharding() = + operand_side_sharding->sharding()->ToProto(); + } + + auto user_side_sharding = + dynamic_cast(user_side_metadata_.get()); + if (user_side_sharding && user_side_sharding->sharding() != nullptr) { + *proto.mutable_domain_exit_sharding() = + 
user_side_sharding->sharding()->ToProto(); + } + + return proto; +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 9c22f5db7e7fbcaff76fc9769d7603ca669e9ee2..5f06dc093248e1d4d36ec845ced1e68c2b9d0752 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -418,14 +418,19 @@ class HloSortInstruction : public HloInstruction { public: explicit HloSortInstruction(const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values = nullptr); + absl::Span values = {}); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } // Returns the sort dimension for this instruction - int64 sort_dimension() { return dimensions(0); } + int64 sort_dimension() const { return dimensions(0); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns the key operand to this instruction. + const HloInstruction* keys() const { return operand(0); } + HloInstruction* mutable_keys() { return mutable_operand(0); } + // Returns the number of value operands. + int64 values_count() const { return operand_count() - 1; } private: std::vector ExtraAttributesToStringImpl( @@ -546,17 +551,6 @@ class HloSliceInstruction : public HloInstruction { } const std::vector& slice_strides() const { return slice_strides_; } - // Returns the flag that describes whether a slice must be lowered into an - // offset into the original operand. - bool IsInPlaceSlice() const { return is_in_place_slice_; } - - // Sets and returns the flag that describes whether a slice must be lowered - // into an offset into the original operand. 
- bool SetIsInPlaceSlice(bool value) { - is_in_place_slice_ = value; - return value; - } - private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -573,9 +567,6 @@ class HloSliceInstruction : public HloInstruction { std::vector slice_starts_; std::vector slice_limits_; std::vector slice_strides_; - - // Describes whether the slice can be lowered to an offset into the operand. - bool is_in_place_slice_ = false; }; class HloConstantInstruction : public HloInstruction { @@ -910,7 +901,6 @@ class HloOutfeedInstruction : public HloInstruction { absl::string_view outfeed_config); // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const { - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_)); return outfeed_shape_; } // Returns the config for the Outfeed instruction. @@ -1068,10 +1058,19 @@ class HloSelectAndScatterInstruction : public HloInstruction { class HloCustomCallInstruction : public HloInstruction { public: - explicit HloCustomCallInstruction(const Shape& shape, - absl::Span operands, - absl::string_view custom_call_target, - absl::string_view opaque); + HloCustomCallInstruction(const Shape& shape, + absl::Span operands, + absl::string_view custom_call_target, + absl::string_view opaque); + + // Constructor for a custom call with constrained layout. 'shape' and + // 'operands_with_layout' must all have layouts. + HloCustomCallInstruction(const Shape& shape, + absl::Span operands, + absl::string_view custom_call_target, + absl::string_view opaque, + absl::Span operand_shapes_with_layout); + const Window& window() const override { CHECK(window_ != nullptr); return *window_; @@ -1100,6 +1099,16 @@ class HloCustomCallInstruction : public HloInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns whether the result and operand layouts are constrained. 
+ bool layout_constrained() const { return layout_constrained_; } + + // Returns the shapes (with layout) of the operands. CHECKs if this custom + // call does not have constrained layouts. + const std::vector& operand_shapes_with_layout() const { + CHECK(layout_constrained()); + return operand_shapes_with_layout_; + } + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -1121,6 +1130,11 @@ class HloCustomCallInstruction : public HloInstruction { std::unique_ptr convolution_dimension_numbers_; // The number of feature groups. This is used for grouped convolutions. int64 feature_group_count_; + // Whether the result and operand layouts are constrained. + bool layout_constrained_; + // For layout-constrained custom calls, this vector holds the shape with + // layout for each operand. + std::vector operand_shapes_with_layout_; }; class HloPadInstruction : public HloInstruction { @@ -1341,6 +1355,9 @@ class HloDomainInstruction : public HloInstruction { std::unique_ptr operand_side_metadata, std::unique_ptr user_side_metadata); + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + // Retrieves the operand side metadata of a kDomain instruction. const DomainMetadata& operand_side_metadata() const { return *operand_side_metadata_; diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index d9be841dd751651ba029998fd062fcaec3691945..971a9a20636c80820306d512af9e7ff4a14b79b5 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -204,7 +204,7 @@ TokKind HloLexer::LexIdentifier() { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); // 'consumable' will be advanced iff its prefix matches the pattern. 
static LazyRE2 shape_pattern = { - R"(^(\w*\d*)\[([\d,]*)\](?:(dense|sparse)?{([\d,]+)})?)"}; + R"(^(\w*\d*)\[([\d,\s]*)\](?:(dense|sparse)?{([\d,\s]+)})?)"}; if (RE2::Consume(&consumable, *shape_pattern)) { auto status_or_shape = ShapeUtil::ParseShapeString( StringPieceFromPointers(token_start_, consumable.begin())); diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 5502e565b6dfbaca6cfa2101950fb0a68c89771f..1717770301e3666b0a1c23d20b7f2e3bac5f62e4 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -179,6 +179,7 @@ HLO_MATCHER(Convolution); HLO_MATCHER(Copy); HLO_MATCHER(CrossReplicaSum); HLO_MATCHER(Divide); +HLO_MATCHER(Domain); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); HLO_MATCHER(Eq); @@ -216,6 +217,7 @@ HLO_MATCHER(Remainder); HLO_MATCHER(Reshape); HLO_MATCHER(Reverse); HLO_MATCHER(Rng); +HLO_MATCHER(Scatter); HLO_MATCHER(Select); HLO_MATCHER(SelectAndScatter); HLO_MATCHER(Send); diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index 6a4e766788f47cad9e168fcccd3a3de9097cacdc..5cee865b7ad34eded1743d9d5455bb40febf6182 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -20,6 +20,8 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -74,7 +76,7 @@ class ListScheduler { const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { ListScheduler scheduler(computation, points_to_analysis, size_function, memory_by_computation); @@ -99,7 +101,7 @@ class ListScheduler { ListScheduler(const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) : computation_(computation), points_to_analysis_(points_to_analysis), @@ -110,7 +112,7 @@ class ListScheduler { // LogicalBuffer is in an operand of the instruction as indicated by // points-to analysis. for (auto* instruction : computation.instructions()) { - tensorflow::gtl::FlatSet instr_uses; + absl::flat_hash_set instr_uses; for (auto* operand : instruction->operands()) { points_to_analysis.GetPointsToSet(operand).ForEachElement( [&](const ShapeIndex& /*index*/, @@ -193,13 +195,15 @@ class ListScheduler { return entry; } - // Returns the number of bytes freed if the HLO instruction is scheduled. - // If the instruction calls subcomputations, we count the memory used by the - // subcomputations as memory "defined" by the instruction. This is not - // entirely accurate, because subcomputation memory will be freed after the - // instruction finishes. But it is more accurate than not taking - // subcomputations into account at all. In the future, we may improve - // accounting for subcomputation memory (b/65409243). 
+ // Returns the number of bytes freed *after* the HLO instruction finishes. + // The current List algorithm only considers two states for an instruction: + // right before it runs, and after it finishes. We don't represent memory + // usage during the execution of an instruction. But if the instruction calls + // subcomputations, they are only live during the instruction's execution. + // We end up counting the memory used by subcomputations as memory "defined" + // by the instruction. This is not entirely accurate, but it is more accurate + // than not taking subcomputations into account at all. In the future, we may + // improve accounting for subcomputation memory (b/65409243). int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { @@ -221,7 +225,18 @@ class ListScheduler { } } } - return freed_bytes - entry.bytes_defined - max_subcomputation_bytes; + int64 bytes_defined; + if (max_subcomputation_bytes > 0 && + (entry.instruction->opcode() == HloOpcode::kWhile || + entry.instruction->opcode() == HloOpcode::kCall || + entry.instruction->opcode() == HloOpcode::kConditional)) { + // The output buffer of while/call/conditional is always aliased with the + // output buffer of the root instruction in the body. Don't double count. + bytes_defined = max_subcomputation_bytes; + } else { + bytes_defined = entry.bytes_defined + max_subcomputation_bytes; + } + return freed_bytes - bytes_defined; } // Constructs the scheduling priority of the given instruction. @@ -234,8 +249,7 @@ class ListScheduler { // Populate the ready list with instructions which have no operands or // control predecessors. - tensorflow::gtl::FlatMap - unscheduled_pred_count; + absl::flat_hash_map unscheduled_pred_count; for (auto* instruction : computation_.instructions()) { // TODO(b/34466113): Replace this and above with successors() or // predecessors() when these methods are added to HloInstruction. 
@@ -251,8 +265,8 @@ class ListScheduler { std::multimap ready_queue; // Map of ready instructions to their iterators in ready_queue. - tensorflow::gtl::FlatMap::iterator> + absl::flat_hash_map::iterator> ready_instructions; auto add_to_ready_queue = [&](HloInstruction* inst) { @@ -262,9 +276,8 @@ class ListScheduler { }; for (auto* instruction : computation_.instructions()) { - // Instruction with no operands or control predecessors will - // not be in the map. - if (unscheduled_pred_count.count(instruction) == 0) { + if (instruction->operands().empty() && + instruction->control_predecessors().empty()) { add_to_ready_queue(instruction); } } @@ -347,21 +360,19 @@ class ListScheduler { // Computations are analyzed in post-order. When scheduling an instruction // that includes subcomputations, such as a while loop, we use this map to // look up the memory needed by subcomputations. - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation_; // A map containing the LogicalBuffers that each instruction uses. - tensorflow::gtl::FlatMap> + absl::flat_hash_map> buffer_uses_; // A map containing the count of unscheduled HLOs which using a particular - // LogicalBuffer. We rely on iterator stability in this map, and that the map - // entries are std::pair's. - std::unordered_map unscheduled_use_count_; + // LogicalBuffer. + absl::flat_hash_map unscheduled_use_count_; // Set of instructions which have been scheduled. 
- tensorflow::gtl::FlatSet scheduled_instructions_; + absl::flat_hash_set scheduled_instructions_; }; int64 SumLogicalBufferSizes( @@ -379,7 +390,7 @@ StatusOr ScheduleComputationHelper( const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { VLOG(2) << "Computation: " << computation.name(); if (algorithm) { @@ -396,13 +407,13 @@ StatusOr DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { // These variables are a hack to prevent overflows. int64 cumulative_total_size = 0; int64 total_hlos = computation.parent()->instruction_count(); - tensorflow::gtl::FlatMap extra_users; - tensorflow::gtl::FlatMap total_sizes; + absl::flat_hash_map extra_users; + absl::flat_hash_map total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { if (ListScheduler::IgnoreInstruction(*hlo)) { extra_users[hlo] = 0; @@ -419,7 +430,7 @@ StatusOr DFSMemoryScheduler( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); total_sizes[hlo] = logical_buffer_size; cumulative_total_size += logical_buffer_size; - tensorflow::gtl::FlatSet unique_operands( + absl::flat_hash_set unique_operands( hlo->operands().begin(), hlo->operands().end()); for (const HloInstruction* operand : unique_operands) { extra_users[hlo] += extra_users[operand]; @@ -467,7 +478,7 @@ StatusOr ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { return ListScheduler::Run(computation, points_to_analysis, size_function, 
memory_by_computation); @@ -477,7 +488,7 @@ StatusOr PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { return HloInstructionSequence(computation.MakeInstructionPostOrder()); } @@ -486,7 +497,7 @@ StatusOr DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { // We try a few schedulers and choose whichever returns a lower min-memory, // not accounting for fragmentation. @@ -549,7 +560,7 @@ StatusOr ScheduleModule( HloSchedule schedule(&module); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); - tensorflow::gtl::FlatMap memory_by_computation; + absl::flat_hash_map memory_by_computation; for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, @@ -577,7 +588,7 @@ StatusOr ScheduleComputation( CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); - tensorflow::gtl::FlatMap empty_map; + absl::flat_hash_map empty_map; return ScheduleComputationHelper(computation, *points_to_analysis, size_function, nullptr, empty_map); } diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index 9964c6fdd7c60a807896ea7aaaa9d55767f20f51..a4c1d3db8170a1725043def576f913e09b352e5d 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -18,6 +18,7 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -37,7 +38,7 @@ namespace xla { typedef std::function( const HloComputation&, const TuplePointsToAnalysis&, const LogicalBuffer::SizeFunction&, - const tensorflow::gtl::FlatMap&)> + const absl::flat_hash_map&)> MemorySchedulerAlgorithm; // List scheduler @@ -45,7 +46,7 @@ StatusOr ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation); // DFS-order scheduler @@ -53,7 +54,7 @@ StatusOr DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation); // Naive Post Order scheduler @@ -61,7 +62,7 @@ StatusOr PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation); // The default scheduling algorithm. 
Runs both the list scheduler @@ -71,7 +72,7 @@ StatusOr DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation); // Returns an HloSchedule which seeks to minimize the memory required for diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 1b9e9bfc77c3ba91e5b878f4aa42d26d8267a49a..214119fba881c4411a262cd4227b5cc49cef0d14 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -146,126 +147,6 @@ ENTRY root { instructions_by_name.at("e"))); } -TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { - // %WhileCond (cond_param: f32[4]) -> pred[] { - // %cond_param = f32[4]{0} parameter(0) - // %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } }) - // ROOT %not-equal-to = pred[] not-equal-to( - // f32[4]{0} %cond_param, f32[1,4]{1,0} %constant) - // } - // %WhileBody (body_param: f32[4]) -> f32[4] { - // %body_param = f32[4]{0} parameter(0) - // %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) - // ROOT %subtract = f32[4]{0} subtract( - // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) - // } - // %ListAccountsForSubcomputations () -> f32[2,4] { - // %constant.3 = f32[2,4]{1,0} constant( - // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) - // %transpose = f32[2,4]{1,0} transpose( - // f32[2,4]{1,0} %constant.3), dimensions={0,1} - // %constant.2 = f32[1,4]{1,0} 
constant(f32[1,4] { { 1, 1, 1, 1 } }) - // %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2), - // condition=%WhileCond, - // body=%WhileBody - // %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0} - // ROOT %add = f32[2,4]{1,0} add( - // f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) - // } - - auto module = CreateNewModule(); - const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); - const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); - - // param != 0 - // Needs 17 bytes - auto cond_builder = HloComputation::Builder("WhileCond"); - HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "cond_param")); - HloInstruction* zero_vector = - cond_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{0, 0, 0, 0}}))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); - auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); - - // param - 1 - // Needs 16 bytes - auto body_builder = HloComputation::Builder("WhileBody"); - HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "body_param")); - HloInstruction* one_vector = - body_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 1, 1, 1}}))); - body_builder.AddInstruction(HloInstruction::CreateBinary( - r1f32, HloOpcode::kSubtract, body_param, one_vector)); - auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); - - // transpose(matrix) + bcast(while) - auto builder = HloComputation::Builder(TestName()); - HloInstruction* while_init = - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 1, 1, 1}}))); - // Creates 16 bytes, ignoring subcomputations - HloInstruction* while_loop = - builder.AddInstruction(HloInstruction::CreateWhile( - r1f32, cond_computation, body_computation, 
while_init)); - - // Creates 32 bytes and frees 16 - HloInstruction* bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(r2f32, while_loop, {0})); - - HloInstruction* matrix = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2( - {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); - // Creates 32 bytes - HloInstruction* transpose = builder.AddInstruction( - HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); - - // Creates 32 bytes and frees 64 - HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); - - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - }; - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); - // Verify that all instructions are in the sequence. - auto entry_computation = module->entry_computation(); - EXPECT_EQ(entry_computation->instruction_count(), - schedule.sequence(entry_computation).size()); - SequentialHloOrdering ordering(schedule); - // This schedule is an example of List's greedy heuristics being suboptimal. - // The while_loop is more expensive than transpose, so it would have been - // better to schedule it first, instead of during the busy time. 
- EXPECT_TRUE(ordering.ExecutesBefore(transpose, while_loop)); - EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); - EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); - EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); - - tensorflow::gtl::FlatMap memory_by_computation; - memory_by_computation[cond_computation] = 17; - memory_by_computation[body_computation] = 16; - std::unique_ptr points_to_analysis = - TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); - - // HeapSimulator doesn't account for subcomputations - EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, schedule.sequence(entry_computation), - *points_to_analysis, size_fn) - .ValueOrDie()); - // HeapSimulator accounts for subcomputations. The output buffer is aliased, - // so we don't double count. - EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, schedule.sequence(entry_computation), - *points_to_analysis, size_fn, &memory_by_computation) - .ValueOrDie()); -} - TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { auto builder = HloComputation::Builder(TestName()); const auto TUPLE_SIZE = 1; @@ -409,7 +290,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { EXPECT_EQ(module->entry_computation()->instruction_count(), schedule.sequence(module->entry_computation()).size()); - tensorflow::gtl::FlatMap memory_by_computation; + absl::flat_hash_map memory_by_computation; memory_by_computation[cond_computation] = 17; memory_by_computation[body_computation] = 16; std::unique_ptr points_to_analysis = diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index b3949f3a6d7176950c61cafb0830d1175f17758d..6845c27a91845ef971dc2d82266200bfccb25533 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -23,6 +23,8 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" @@ -71,6 +73,8 @@ HloComputation* HloModule::AddComputationInternal( config_.SetDefaultComputationLayout( entry_computation_->ComputeProgramShape()); } + input_output_alias_config_ = HloInputOutputAliasConfig( + entry_computation_->root_instruction()->shape()); } if (uniquify_identifiers) { @@ -144,7 +148,8 @@ void HloModule::ReplaceComputations( case HloOpcode::kCall: case HloOpcode::kMap: case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: { + case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: { HloComputation* new_arg = tensorflow::gtl::FindWithDefault( replacements, instruction->to_apply(), nullptr); if (new_arg != nullptr) { @@ -241,14 +246,14 @@ HloModuleProto HloModule::ToProto() const { proto.set_entry_computation_id(entry_computation_->unique_id()); for (const HloComputation* computation : MakeComputationPostOrder()) { HloComputationProto computation_proto = computation->ToProto(); - if (computation->name() == entry_computation_->name()) { - *proto.mutable_program_shape() = computation_proto.program_shape(); - } proto.add_computations()->Swap(&computation_proto); } if (has_schedule()) { *proto.mutable_schedule() = schedule().ToProto().ValueOrDie(); } + *proto.mutable_host_program_shape() = + entry_computation_layout().ComputeProgramShape(); + *proto.mutable_input_output_alias() = input_output_alias_config().ToProto(); return proto; } @@ -260,9 +265,9 @@ StatusOr> HloModule::CreateFromProto( // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. 
- TF_RET_CHECK(proto.has_program_shape()) + TF_RET_CHECK(proto.has_host_program_shape()) << "No program shape found in the proto"; - const auto& expected_program_shape = proto.program_shape(); + const auto& expected_program_shape = proto.host_program_shape(); TF_RET_CHECK(expected_program_shape.parameters_size() == module_config.entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { @@ -285,8 +290,8 @@ StatusOr> HloModule::CreateFromProto( << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape); - tensorflow::gtl::FlatMap computation_map; - tensorflow::gtl::FlatMap to_proto_id; + absl::flat_hash_map computation_map; + absl::flat_hash_map to_proto_id; std::vector> computations; HloComputation* entry = nullptr; for (const HloComputationProto& computation_proto : proto.computations()) { @@ -325,12 +330,16 @@ StatusOr> HloModule::CreateFromProto( } TF_RET_CHECK(module->entry_computation_ != nullptr); + TF_ASSIGN_OR_RETURN(module->input_output_alias_config_, + HloInputOutputAliasConfig::CreateFromProto( + result_shape, proto.input_output_alias())); + // Because we didn't uniquify the names or the ids, double-check that the // instruction and computation names and ids are unique from the proto. 
- tensorflow::gtl::FlatSet computation_names; - tensorflow::gtl::FlatSet instruction_names; - tensorflow::gtl::FlatSet computation_ids; - tensorflow::gtl::FlatSet instruction_ids; + absl::flat_hash_set computation_names; + absl::flat_hash_set instruction_names; + absl::flat_hash_set computation_ids; + absl::flat_hash_set instruction_ids; for (HloComputation* computation : module->computations()) { TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); @@ -363,9 +372,9 @@ StatusOr> HloModule::CreateFromProto( /* static */ StatusOr HloModule::CreateModuleConfigFromProto( const HloModuleProto& module, const DebugOptions& debug_options) { - TF_RET_CHECK(module.has_program_shape()) + TF_RET_CHECK(module.has_host_program_shape()) << "No program shape found in the proto"; - const auto& program_shape = module.program_shape(); + const auto& program_shape = module.host_program_shape(); HloModuleConfig module_config(program_shape); module_config.set_debug_options(debug_options); @@ -555,8 +564,13 @@ std::vector HloModule::MakeNonfusionComputations() const { } std::unique_ptr HloModule::Clone(const string& suffix) const { + return Clone(config(), suffix); +} + +std::unique_ptr HloModule::Clone(const HloModuleConfig& config, + const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = absl::make_unique(name_ + "-" + suffix, config_); + auto module = absl::make_unique(name_ + "-" + suffix, config); HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 735804e827afd77e2b7f2a4a7d490ee6f5ee7b4f..5dc795fabec5d8d794635ef6965c4d065b0b75a6 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -31,6 +31,7 @@ limitations under the 
License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" @@ -92,6 +93,8 @@ class HloModule { // Returns a deep copy of this module including all computations. std::unique_ptr Clone(const string& suffix = "clone") const; + std::unique_ptr Clone(const HloModuleConfig& config, + const string& suffix = "clone") const; // Performs a deep clone of the computation, by recursively cloning all // the called computations as well. If the clone context is specified, it @@ -99,7 +102,7 @@ class HloModule { HloComputation* DeepCloneComputation(HloComputation* computation, HloCloneContext* context = nullptr); - // Return a pointer to the entry computation of the module.. + // Return a pointer to the entry computation of the module. const HloComputation* entry_computation() const { CHECK_NE(nullptr, entry_computation_); return entry_computation_; @@ -109,6 +112,14 @@ class HloModule { return entry_computation_; } + // Returns the root instruction shape of entry computation. + // + // Precondition: entry_computation_ is not nullptr. + const Shape& result_shape() const { + CHECK_NE(nullptr, entry_computation_); + return entry_computation()->root_instruction()->shape(); + } + // Creates the ComputationLayout which describes the current status of the HLO // module entry computation. ComputationLayout compute_computation_layout() const { @@ -212,9 +223,14 @@ class HloModule { return result; } - // Returns the number of unique intruction ids given out. 
All ids up to - // this point are guaranteed to be in the range [0..NumUniqueInstructionIds()) - int NumUniqueInstructionIds() const { return next_unique_id_; } + // input_output_alias_config indicates the list of aliased buffers that are + // expected from the module. + HloInputOutputAliasConfig& input_output_alias_config() { + return input_output_alias_config_; + } + const HloInputOutputAliasConfig& input_output_alias_config() const { + return input_output_alias_config_; + } // Returns an id that is unique to this module across all modules created over // the lifetime of this process. @@ -284,6 +300,10 @@ class HloModule { // sequential order of instructions for each non-fusion computation in the // module. absl::optional schedule_; + + // alias_config indicates the alias information of input/output buffers that + // are expected from the module. + HloInputOutputAliasConfig input_output_alias_config_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc index f9b56ef4643f2ca88e56456ae6c990161adb5085..8999ac9f324ed24cf34ef6826000e1fa4f741e19 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group.cc @@ -17,9 +17,8 @@ limitations under the License. 
namespace xla { -HloModuleGroup::HloModuleGroup(absl::string_view name, - std::unique_ptr module) - : name_(name) { +HloModuleGroup::HloModuleGroup(std::unique_ptr module) + : name_(module->name()) { push_back(std::move(module)); } diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h index 7338be8b9c5ed47f0ba5829cc1d603b21f00b6e0..7c39cf17815aa08742e6d5b35941d8043531d034 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group.h +++ b/tensorflow/compiler/xla/service/hlo_module_group.h @@ -35,7 +35,7 @@ class HloModuleGroup { explicit HloModuleGroup(absl::string_view name) : name_(name) {} // Construct a module group containing a single module. - HloModuleGroup(absl::string_view name, std::unique_ptr module); + explicit HloModuleGroup(std::unique_ptr module); // Construct a module group containing any number of modules. HloModuleGroup(absl::string_view name, @@ -50,11 +50,16 @@ class HloModuleGroup { // Add a module to the back of vector of modules in the group. void push_back(std::unique_ptr module); + // Replaces the existing module at the given index with the given module. The + // existing module is discarded. + void ReplaceModule(int index, std::unique_ptr module); + // Moves all modules from the group into the returned vector. After this // method runs, the module group will be empty. std::vector> ConsumeModules(); string name() const { return name_; } + string ToString() const; // Serialize the module group to/from a proto. @@ -63,6 +68,12 @@ class HloModuleGroup { const HloModuleGroupProto& proto, absl::Span module_configs); + // Returns the number of modules in the module group. + int size() const { return modules_.size(); } + + // Returns true if there are no modules in the module group. 
+ bool empty() const { return modules_.empty(); } + private: string name_; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 83352ef91b35b61ee2560b1488ee2ecdff6bea0a..b4aac4c8076cb69647d42c6243bc969d06d0709e 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -59,7 +59,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { } /* static */ StatusOr> -HloModuleGroupMetadata::Build(const std::vector& modules) { +HloModuleGroupMetadata::Build(absl::Span modules) { auto metadata = absl::make_unique(modules); TF_RETURN_IF_ERROR(metadata->Build()); return std::move(metadata); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 278d94cdd337c835bc0ff98ea577ef7b8c3ddd03..928df0f5a7444ad877961a5de970c752e1d024da 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -30,7 +31,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -102,14 +102,14 @@ class HloModuleGroupMetadata { HloInstruction* recv_done = nullptr; }; - explicit HloModuleGroupMetadata(const std::vector& modules) - : modules_(modules) {} + explicit HloModuleGroupMetadata(absl::Span modules) + : modules_(modules.begin(), modules.end()) {} ~HloModuleGroupMetadata() = default; // Build and return the metadata for the given modules. static StatusOr> Build( - const std::vector& modules); + absl::Span modules); // Returns true if the instruction is one of the 4 channel instructions (Send, // Recv, SendDone, RecvDone). @@ -250,33 +250,33 @@ class HloModuleGroupMetadata { std::vector>> companion_sets_; // Map from each companion while instruction to the index into companion_set_. - tensorflow::gtl::FlatMap companion_set_index_; + absl::flat_hash_map companion_set_index_; // Map from computation to the instruction using it (a kWhile, kConditional). - tensorflow::gtl::FlatMap + absl::flat_hash_map tracked_instructions_; // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of // communicating instructions within the proper called computation(s). - tensorflow::gtl::FlatMap> + absl::flat_hash_map> tracked_instructions_comms_; // All channels in the module. std::vector channels_; // Map from channel ids to the index in channels_. - tensorflow::gtl::FlatMap channel_id_map_; + absl::flat_hash_map channel_id_map_; // Map from all-reduce ids to the all reduce instructions. - tensorflow::gtl::FlatMap> all_reduce_map_; + absl::flat_hash_map> all_reduce_map_; // The maximum channel id used in the module group. int64 max_channel_id_ = -1; // The modules that this metadata was built from. 
- const std::vector& modules_; + const std::vector modules_; - tensorflow::gtl::FlatMap> + absl::flat_hash_map> points_to_analyses_; }; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc index b7b12cb72b8df4610b964fb842da78e160d22d9f..5a9a86af5649bf240bb5de6d30fc80b0f6a58eba 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc @@ -46,7 +46,7 @@ ENTRY %entry (x: f32[], y: f32[]) -> f32[] { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseHloString(text)); - HloModuleGroup group(TestName(), std::move(module)); + HloModuleGroup group(std::move(module)); EXPECT_EQ(group.modules().size(), 1); EXPECT_THAT( diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index d83ee714905252e36f38438e81002a4d6ba7dafa..fddeb5f0a27a43ff9ca8b2b5d314bcfe91aaf0e6 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -42,7 +42,7 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( HloInstruction* instruction) { std::vector predecessors; // Use a vector to avoid non-determinism. 
- tensorflow::gtl::FlatSet unique; + absl::flat_hash_set unique; // Adds to the unique predecessors list; if the predecessors is a companion // instruction, also add companion instructions; if the predecessors is a @@ -119,7 +119,7 @@ std::vector HloModuleGroupUtil::GlobalSuccessors( HloInstruction* instruction) { std::vector successors; // Use a vector to avoid non-determinism. - tensorflow::gtl::FlatSet unique; + absl::flat_hash_set unique; // Adds to the unique successors list; if the successor is a companion // instruction, also add companion instructions; if the successor is a diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h index 309c23045d1e0dd91e2f245d00c51d9bf9961bf5..f21b44bcd98d77b831de5d8a6afa4f9ddd91d15d 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -87,7 +87,7 @@ class HloModuleGroupUtil { // * visit_state: map from each instruction to its visit state. // * visit_function: function called when each instruction group. // * root: the root instruction of the traversal. 
- using VisitStates = tensorflow::gtl::FlatMap; + using VisitStates = absl::flat_hash_map; Status VisitTopologicalOrder(VisitStates* visit_state, const VisitFunction& visit_function, HloInstruction* root); diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 2d4e38589fe4693e73c46d6c82e51cb0a8388f85..4551a1c2e259b06818f913cb6a9e782436b7e594 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -31,7 +31,7 @@ string HloOpcodeString(HloOpcode opcode) { } StatusOr StringToHloOpcode(const string& opcode_name) { - static auto* opcode_map = new tensorflow::gtl::FlatMap({ + static auto* opcode_map = new absl::flat_hash_map({ #define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \ {opcode_name, HloOpcode::enum_name}, HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY) diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index f1dc08bafa17a2dd68a7e922d4b84658bbf2589c..23d41d91d6969ddf9062507e926ae39c1e1315d4 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -92,14 +92,18 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, } bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { - // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b' - // is live into the module. + // Entry parameter should always be defined before other instructions. 
const HloModule* module = b.defining_instruction()->parent()->parent(); if (b.defining_instruction()->parent() == module->entry_computation() && b.defining_instruction()->opcode() == HloOpcode::kParameter) { return false; } + if (a.defining_instruction()->parent() == module->entry_computation() && + a.defining_instruction()->opcode() == HloOpcode::kParameter) { + return true; + } + // Phi values require special handling. Because XLA does not have a phi // instruction, the definition instruction of the phis values are // placeholders: either the subcomputation parameter (body or condition) or @@ -316,7 +320,7 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const { for (auto predecessor : all) { if (predecessors_.at(computation) ->IsReachable(predecessor, instruction)) { - pieces.push_back(absl::StrFormat(" %s", predecessor->name())); + pieces.push_back(absl::StrFormat(" %s", predecessor->name())); } } } diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index b0361c3f02922bcaa14d52ad3b240701080f9b58..66313492eb2dd10ac9a6000639ddb8991b367c0f 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -120,8 +120,8 @@ class PredecessorHloOrdering : public HloOrdering { // predecessors. An instruction is an element of its own predecessor set. 
// // Subclasses should fill this in to define the desired ordering. - tensorflow::gtl::FlatMap> + absl::flat_hash_map> predecessors_; }; @@ -204,7 +204,7 @@ class SequentialHloOrdering : public HloOrdering { // this map so more than one instruction may have the same position // value. This is not a problem because ExecutesBefore also verifies // instructions are in the same computation. - tensorflow::gtl::FlatMap order_position_; + absl::flat_hash_map order_position_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 00970bcda34209d33867099d0bcf3b2902d52ae8..b045adc9640ac0ca8cf4a127fea2fbfcbb1aaf3f 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -174,6 +174,26 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param)); } +TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) { + // Entry parameter should always be defined before other instruction. 
+ auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + module->AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + DependencyHloOrdering ordering(module.get()); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(param), + dataflow->GetValueDefinedAt(constant))); + EXPECT_TRUE(!ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant), + dataflow->GetValueDefinedAt(param))); +} + TEST_F(HloOrderingTest, ValuesInWhileComputations) { // Tests the ordering of values (defined by dataflow analysis) in the body and // condition of a while instruction. HLO code: diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 25b70740e375cbba8f24b136fb65bdbd038ef958..81f091238e5725f64b953f70b82d52cc90aef8ea 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -68,7 +68,7 @@ class HloParser { // Runs the parser and constructs the resulting HLO in the given (empty) // HloModule. Returns false if an error occurred. - bool Run(HloModule* module); + Status Run(HloModule* module); // Returns the error information. string GetError() const { return StrJoin(error_, "\n"); } @@ -79,28 +79,37 @@ class HloParser { StatusOr ParseConvolutionDimensionNumbersOnly(); StatusOr ParsePaddingConfigOnly(); - // Stand-alone parsing utility for a single instruction worth of text. 
- Status ParseSingleInstruction(HloComputation::Builder* builder, - string* root_name); - private: - // Locates an instruction with the given name in the instruction_pool_ or + using InstrNameTable = + std::unordered_map>; + + // Returns the map from the instruction name to the instruction itself and its + // location in the current scope. + InstrNameTable& current_name_table() { return scoped_name_tables_.back(); } + + // Locates an instruction with the given name in the current_name_table() or // returns nullptr. // - // If the missing_instruction_hook_ is registered and a "shape" is provided, - // the hook will be called and may satisfy the request for the given - // instruction. This is useful when we reify parameters as they're resolved; - // i.e. for ParseSingleInstruction. + // When the name is not found or name is empty, if create_missing_instruction_ + // hook is registered and a "shape" is provided, the hook will be called to + // create an instruction. This is useful when we reify parameters as they're + // resolved; i.e. for ParseSingleInstruction. std::pair* FindInstruction( const string& name, const optional& shape = nullopt); + // Parse a single instruction worth of text. + bool ParseSingleInstruction(HloModule* module); + // ParseXXX returns false if an error occurred. 
bool ParseHloModule(HloModule* module); + bool ParseComputations(HloModule* module); bool ParseComputation(HloComputation** entry_computation); - bool ParseInstructionList(HloComputation::Builder* builder, - string* root_name); + bool ParseInstructionList(HloComputation** computation, + const string& computation_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); + bool ParseInstruciontRhs(HloComputation::Builder* builder, const string& name, + LocTy name_loc); bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(Literal* literal, const Shape& shape); bool ParseTupleLiteral(Literal* literal, const Shape& shape); @@ -165,6 +174,7 @@ class HloParser { kDistribution, kDomain, kPrecisionList, + kShapeList }; struct AttrConfig { @@ -231,6 +241,7 @@ class HloParser { bool ParseSliceRanges(SliceRanges* result); bool ParsePrecisionList(std::vector* result); + bool ParseShapeList(std::vector* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); @@ -281,23 +292,47 @@ class HloParser { bool AddComputation(const string& name, HloComputation* computation, LocTy name_loc); - // The map from the instruction/computation name to the - // instruction/computation itself and it's location. This does not own the - // pointers. - std::unordered_map> - instruction_pool_; + HloLexer lexer_; + + // A stack for the instruction names. The top of the stack stores the + // instruction name table for the current scope. + // + // An instruction's name is unique within its scope (i.e. its parent + // computation), but it's not necessarily unique among all computations in the + // module. When there are multiple levels of nested computations, the same + // name could appear in both an outer computation and an inner computation.
So + // we need a stack to make sure a name is only visible within its scope. + std::vector scoped_name_tables_; + + // A helper class which pushes and pops to an InstrNameTable stack via RAII. + class Scope { + public: + explicit Scope(std::vector* scoped_name_tables) + : scoped_name_tables_(scoped_name_tables) { + scoped_name_tables_->emplace_back(); + } + ~Scope() { scoped_name_tables_->pop_back(); } + + private: + std::vector* scoped_name_tables_; + }; + + // Map from the computation name to the computation itself and its location. std::unordered_map> computation_pool_; - HloLexer lexer_; std::vector> computations_; std::vector error_; - // Function that gets invoked when we try to resolve an instruction - // instruction_pool_ but fail to do so. - std::function*(string, - const optional&)> - missing_instruction_hook_; + // When an operand name cannot be resolved, this function is called to create + // a parameter instruction with the given name and shape. It registers the + // name, instruction, and a placeholder location in the name table. It returns + // the newly-created instruction and the placeholder location. If `name` is + // empty, this should create the parameter with a generated name. This is + // supposed to be set and used only in ParseSingleInstruction. + std::function*(const string& name, + const Shape& shape)> + create_missing_instruction_; }; bool SplitToInt64s(absl::string_view s, char delim, std::vector* out) { @@ -344,18 +379,44 @@ bool HloParser::TokenError(absl::string_view msg) { return Error(lexer_.GetLoc(), msg); } -bool HloParser::Run(HloModule* module) { +Status HloParser::Run(HloModule* module) { lexer_.Lex(); - return ParseHloModule(module); + if (lexer_.GetKind() == TokKind::kw_HloModule) { + // This means that the text contains a full HLO module.
+ if (!ParseHloModule(module)) { + return InvalidArgument( + "Syntax error when trying to parse the text as a HloModule:\n%s", + GetError()); + } + return Status::OK(); + } + // This means that the text is a single HLO instruction. + if (!ParseSingleInstruction(module)) { + return InvalidArgument( + "Syntax error when trying to parse the text as a single " + "HloInstruction:\n%s", + GetError()); + } + return Status::OK(); } std::pair* HloParser::FindInstruction( const string& name, const optional& shape) { - std::pair* instr = - tensorflow::gtl::FindOrNull(instruction_pool_, name); + std::pair* instr = nullptr; + if (!name.empty()) { + instr = tensorflow::gtl::FindOrNull(current_name_table(), name); + } + // Potentially call the missing instruction hook. - if (instr == nullptr && missing_instruction_hook_ != nullptr) { - return missing_instruction_hook_(name, shape); + if (instr == nullptr && create_missing_instruction_ != nullptr && + scoped_name_tables_.size() == 1) { + if (!shape.has_value()) { + Error(lexer_.GetLoc(), + "Operand had no shape in HLO text; cannot create parameter for " + "single-instruction module."); + return nullptr; + } + return create_missing_instruction_(name, *shape); } return instr; } @@ -439,7 +500,6 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { if (!ParseName(&name)) { return false; } - auto builder = absl::make_unique(name); LocTy shape_loc = nullptr; Shape shape; @@ -447,40 +507,21 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { return false; } - string root_name; - if (!ParseInstructionList(builder.get(), &root_name)) { + HloComputation* computation = nullptr; + if (!ParseInstructionList(&computation, name)) { return false; } - std::pair* root_node = FindInstruction(root_name); - // This means some instruction was marked as ROOT but we didn't find it in the - // pool, which should not happen. 
- if (!root_name.empty() && root_node == nullptr) { - LOG(FATAL) << "instruction " << root_name - << " was marked as ROOT but the parser has not seen it before"; - } - - HloInstruction* root = root_node == nullptr ? nullptr : root_node->first; - // Now root can be either an existing instruction or a nullptr. If it's a - // nullptr, the implementation of Builder will set the last instruction as - // root instruction. - computations_.emplace_back(builder->Build(root)); - HloComputation* computation = computations_.back().get(); - - if (!root) { - root = computation->root_instruction(); - } else { - CHECK_EQ(root, computation->root_instruction()); - } - // If param_list_to_shape was present, check compatibility. - if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) { + if (shape_loc != nullptr && + !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) { return Error( shape_loc, - StrCat("Shape of computation ", name, ", ", - ShapeUtil::HumanString(shape), - ", is not compatible with that of its root instruction ", - root_name, ", ", ShapeUtil::HumanString(root->shape()))); + StrCat( + "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape), + ", is not compatible with that of its root instruction ", + computation->root_instruction()->name(), ", ", + ShapeUtil::HumanString(computation->root_instruction()->shape()))); } if (is_entry_computation) { @@ -489,43 +530,62 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { } *entry_computation = computation; } - instruction_pool_.clear(); return AddComputation(name, computation, name_loc); } // instruction_list ::= '{' instruction_list1 '}' // instruction_list1 ::= (instruction)+ -bool HloParser::ParseInstructionList(HloComputation::Builder* builder, - string* root_name) { +bool HloParser::ParseInstructionList(HloComputation** computation, + const string& computation_name) { + Scope scope(&scoped_name_tables_); + HloComputation::Builder 
builder(computation_name); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of instruction list.")) { return false; } + string root_name; do { - if (!ParseInstruction(builder, root_name)) { + if (!ParseInstruction(&builder, &root_name)) { return false; } } while (lexer_.GetKind() != TokKind::kRbrace); - return ParseToken(TokKind::kRbrace, - "expects '}' at the end of instruction list."); + if (!ParseToken(TokKind::kRbrace, + "expects '}' at the end of instruction list.")) { + return false; + } + HloInstruction* root = nullptr; + if (!root_name.empty()) { + std::pair* root_node = + tensorflow::gtl::FindOrNull(current_name_table(), root_name); + + // This means some instruction was marked as ROOT but we didn't find it in + // the pool, which should not happen. + if (root_node == nullptr) { + LOG(FATAL) << "instruction " << root_name + << " was marked as ROOT but the parser has not seen it before"; + } + root = root_node->first; + } + + // Now root can be either an existing instruction or a nullptr. If it's a + // nullptr, the implementation of Builder will set the last instruction as + // the root instruction. + computations_.emplace_back(builder.Build(root)); + *computation = computations_.back().get(); + return true; } // instruction ::= ('ROOT')? 
name '=' shape opcode operands (attribute)* bool HloParser::ParseInstruction(HloComputation::Builder* builder, string* root_name) { string name; - Shape shape; - HloOpcode opcode; - std::vector operands; - LocTy maybe_root_loc = lexer_.GetLoc(); bool is_root = EatIfPresent(TokKind::kw_ROOT); const LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name) || - !ParseToken(TokKind::kEqual, "expects '=' in instruction") || - !ParseShape(&shape) || !ParseOpcode(&opcode)) { + !ParseToken(TokKind::kEqual, "expects '=' in instruction")) { return false; } @@ -536,6 +596,19 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, *root_name = name; } + return ParseInstruciontRhs(builder, name, name_loc); +} + +bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, + const string& name, LocTy name_loc) { + Shape shape; + HloOpcode opcode; + std::vector operands; + + if (!ParseShape(&shape) || !ParseOpcode(&opcode)) { + return false; + } + // Add optional attributes. std::unordered_map attrs; optional sharding; @@ -766,8 +839,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kSort: { - auto loc = lexer_.GetLoc(); - optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; @@ -775,20 +846,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, dimensions->size() != 1) { return false; } - switch (operands.size()) { - case 1: - instruction = builder->AddInstruction(HloInstruction::CreateSort( - shape, dimensions->at(0), /*keys=*/operands[0])); - break; - case 2: - instruction = builder->AddInstruction(HloInstruction::CreateSort( - shape, dimensions->at(0), - /*keys=*/operands[0], /*values=*/operands[1])); - break; - default: - return Error(loc, StrCat("expects either 1 or 2 operands, but has ", - operands.size(), " operands")); - } + instruction = builder->AddInstruction(HloInstruction::CreateSort( + shape, dimensions->at(0), + 
/*keys=*/operands[0], + /*values=*/absl::Span(operands).subspan(1))); break; } case HloOpcode::kTuple: { @@ -1270,6 +1331,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional window; optional dnums; optional feature_group_count; + optional> operand_layout_constraints; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque}; @@ -1278,12 +1340,52 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + attrs["operand_layout_constraints"] = { + /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateCustomCall(shape, operands, *custom_call_target, - opaque.has_value() ? 
*opaque : "")); + if (operand_layout_constraints.has_value()) { + if (!LayoutUtil::HasLayout(shape)) { + return Error(lexer_.GetLoc(), + "Layout must be set on layout-constrained custom call"); + } + if (operands.size() != operand_layout_constraints->size()) { + return Error(lexer_.GetLoc(), + StrCat("Expected ", operands.size(), + " operand layout constraints, ", + operand_layout_constraints->size(), " given")); + } + for (int64 i = 0; i < operands.size(); ++i) { + const Shape& operand_shape_with_layout = + (*operand_layout_constraints)[i]; + if (!LayoutUtil::HasLayout(operand_shape_with_layout)) { + return Error(lexer_.GetLoc(), + StrCat("Operand layout constraint shape ", + ShapeUtil::HumanStringWithLayout( + operand_shape_with_layout), + " for operand ", i, " does not have a layout")); + } + if (!ShapeUtil::Compatible(operand_shape_with_layout, + operands[i]->shape())) { + return Error( + lexer_.GetLoc(), + StrCat( + "Operand layout constraint shape ", + ShapeUtil::HumanStringWithLayout(operand_shape_with_layout), + " for operand ", i, + " is not compatible with operand shape ", + ShapeUtil::HumanStringWithLayout(operands[i]->shape()))); + } + } + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target, *operand_layout_constraints, + opaque.has_value() ? *opaque : "")); + } else { + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target, + opaque.has_value() ? 
*opaque : "")); + } if (window.has_value()) { instruction->set_window(*window); } @@ -2146,7 +2248,20 @@ bool HloParser::ParseOperands(std::vector* operands) { } } if (!ParseName(&name)) { - return false; + // When parsing a single instruction (as opposed to a whole module), an + // HLO may have one or more operands with a shape but no name: + // + // foo = add(f32[10], f32[10]) + // + // create_missing_instruction_ is always non-null when parsing a single + // instruction, and is responsible for creating kParameter instructions + // for these operands. + if (shape.has_value() && create_missing_instruction_ != nullptr && + scoped_name_tables_.size() == 1) { + name = ""; + } else { + return false; + } } std::pair* instruction = FindInstruction(name, shape); @@ -2299,9 +2414,17 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kHloComputation: { - HloComputation* result; - if (!ParseComputationName(&result)) { - return false; + HloComputation* result = nullptr; + if (lexer_.GetKind() == TokKind::kLbrace) { + // This means it is a nested computation. + if (!ParseInstructionList(&result, /*computation_name=*/"_")) { + return false; + } + } else { + // This means it is a computation name. 
+ if (!ParseComputationName(&result)) { + return false; + } } static_cast*>(attr_out_ptr)->emplace(result); return true; @@ -2441,6 +2564,15 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kShapeList: { + std::vector result; + if (!ParseShapeList(&result)) { + return false; + } + static_cast>*>(attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { @@ -2733,6 +2865,23 @@ bool HloParser::ParsePrecisionList( parse_and_add_item); } +// shapelist ::= '{' shapes '}' +// shapes +// ::= /*empty*/ +// ::= shape (',' shape)* +bool HloParser::ParseShapeList(std::vector* result) { + auto parse_and_add_item = [&]() { + Shape shape; + if (!ParseShape(&shape)) { + return false; + } + result->push_back(std::move(shape)); + return true; + }; + return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item); +} + // int64list ::= start int64_elements end // int64_elements // ::= /*empty*/ @@ -2740,23 +2889,15 @@ bool HloParser::ParsePrecisionList( bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result) { - if (!ParseToken(start, StrCat("expects an int64 list starting with ", - TokKindToString(start)))) { - return false; - } - if (lexer_.GetKind() == end) { - // empty - } else { - do { - tensorflow::int64 i; - if (!ParseInt64(&i)) { - return false; - } - result->push_back(i); - } while (EatIfPresent(delim)); - } - return ParseToken( - end, StrCat("expects an int64 list to end with ", TokKindToString(end))); + auto parse_and_add_item = [&]() { + tensorflow::int64 i; + if (!ParseInt64(&i)) { + return false; + } + result->push_back(i); + return true; + }; + return ParseList(start, end, delim, parse_and_add_item); } bool HloParser::ParseList(const TokKind start, const TokKind end, @@ -2841,7 +2982,8 @@ bool HloParser::ParseShape(Shape* result) { } if (lexer_.GetKind() != TokKind::kShape) { - return TokenError("expects shape"); +
return TokenError(absl::StrCat("expected shape, saw ", + TokKindToString(lexer_.GetKind()))); } *result = lexer_.GetShapeVal(); lexer_.Lex(); @@ -3134,7 +3276,7 @@ bool HloParser::EatIfPresent(TokKind kind) { bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, LocTy name_loc) { - auto result = instruction_pool_.insert({name, {instruction, name_loc}}); + auto result = current_name_table().insert({name, {instruction, name_loc}}); if (!result.second) { Error(name_loc, StrCat("instruction already exists: ", name)); return Error(/*loc=*/result.first->second.second, @@ -3204,82 +3346,78 @@ StatusOr HloParser::ParsePaddingConfigOnly() { return padding_config; } -Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder, - string* root_name) { - TF_RET_CHECK(missing_instruction_hook_ == nullptr); +bool HloParser::ParseSingleInstruction(HloModule* module) { + if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) { + LOG(FATAL) << "Parser state is not clean. Please do not call any other " + "methods before calling ParseSingleInstruction."; + } + HloComputation::Builder builder(module->name()); // The missing instruction hook we register creates the shaped instruction on // the fly as a parameter and returns it. 
int64 parameter_count = 0; - missing_instruction_hook_ = - [this, builder, &parameter_count]( - string name, - const optional& shape) -> std::pair* { - if (!shape.has_value()) { - Error(lexer_.GetLoc(), - StrCat("Operand ", name, - " had no shape in HLO text; cannot create parameter for " - "single-instruction module.")); - return nullptr; - } - HloInstruction* parameter = builder->AddInstruction( - HloInstruction::CreateParameter(parameter_count++, *shape, name)); - instruction_pool_[name] = {parameter, lexer_.GetLoc()}; - return tensorflow::gtl::FindOrNull(instruction_pool_, name); + create_missing_instruction_ = + [this, &builder, &parameter_count]( + const string& name, + const Shape& shape) -> std::pair* { + string new_name = name.empty() ? StrCat("_", parameter_count) : name; + HloInstruction* parameter = builder.AddInstruction( + HloInstruction::CreateParameter(parameter_count++, shape, new_name)); + current_name_table()[new_name] = {parameter, lexer_.GetLoc()}; + return tensorflow::gtl::FindOrNull(current_name_table(), new_name); }; - // Prime the lexer. - lexer_.Lex(); - // Parse the instruction with the registered hook. - if (!ParseInstruction(builder, root_name)) { - return InvalidArgument("Syntax error:\n%s", GetError()); + Scope scope(&scoped_name_tables_); + if (CanBeShape()) { + // This means that the instruction's left-hand side is probably omitted, + // e.g. + // + // f32[10] fusion(...), calls={...} + if (!ParseInstruciontRhs(&builder, module->name(), lexer_.GetLoc())) { + return false; + } + } else { + // This means that the instruction's left-hand side might exist, e.g. 
+ // + // foo = f32[10] fusion(...), calls={...} + string root_name; + if (!ParseInstruction(&builder, &root_name)) { + return false; + } } - return Status::OK(); + + module->AddEntryComputation(builder.Build()); + for (auto& comp : computations_) { + module->AddEmbeddedComputation(std::move(comp)); + } + return true; } } // namespace StatusOr> ParseHloString( absl::string_view str, const HloModuleConfig& config) { - auto module = absl::make_unique(/*name=*/"", config); + auto module = absl::make_unique(/*name=*/"_", config); HloParser parser(str); - if (!parser.Run(module.get())) { - return InvalidArgument("Syntax error:\n%s", parser.GetError()); - } + TF_RETURN_IF_ERROR(parser.Run(module.get())); return std::move(module); } StatusOr> ParseHloString(absl::string_view str) { - auto module = absl::make_unique(/*name=*/"", HloModuleConfig()); + auto module = absl::make_unique(/*name=*/"_", HloModuleConfig()); HloParser parser(str); - if (!parser.Run(module.get())) { - return InvalidArgument("Syntax error:\n%s", parser.GetError()); - } + TF_RETURN_IF_ERROR(parser.Run(module.get())); return std::move(module); } Status ParseHloString(absl::string_view str, HloModule* module) { TF_RET_CHECK(module->computation_count() == 0); HloParser parser(str); - if (!parser.Run(module)) { - return InvalidArgument("Syntax error:\n%s", parser.GetError()); - } + TF_RETURN_IF_ERROR(parser.Run(module)); return Status::OK(); } -StatusOr> ParseHloOpToModule( - absl::string_view str, absl::string_view name) { - HloParser parser(str); - auto builder = absl::make_unique(string(name)); - string root_name; - TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name)); - std::unique_ptr computation = builder->Build(); - auto module = absl::make_unique(string(name), HloModuleConfig()); - module->AddEntryComputation(std::move(computation)); - return std::move(module); -} - StatusOr ParseSharding(absl::string_view str) { HloParser parser(str); return parser.ParseShardingOnly(); diff 
--git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 369603551463fd4b4911b393f3c6c2b36f0e4bbb..81eeb9f13bf7f06123c0b35e9f3352c197866a7a 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -40,11 +40,6 @@ StatusOr> ParseHloString( // point to an empty module (no computations). Status ParseHloString(absl::string_view str, HloModule* module); -// Parses the text for a single HLO operation into an HLO module with a function -// that runs that operation (with the same parameters) as its entry computation. -StatusOr> ParseHloOpToModule( - absl::string_view str, absl::string_view name = "single_op"); - // Given a string in the HloModule::ToString() format, parses the string and // creates a HloModule with default config. StatusOr> ParseHloString(absl::string_view str); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 96db96bdb96359ac8b694c13eabd528157701c99..19f84d8bd28371518e44e38614b8a81fa920985f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -802,6 +802,43 @@ ENTRY %ConstantUnsignedNoOverflow () -> u64[] { ROOT %constant = u64[] constant(9223372036854775807) } +)" +}, +// CustomCallWithLayoutConstraints +{ +"CustomCallWithLayoutConstraints", +R"(HloModule CustomCallWithLayoutConstraints + +ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}} +} + +)" +}, +// CustomCallWithLayoutConstraintsNoOperands +{ +"CustomCallWithLayoutConstraintsNoOperands", +R"(HloModule CustomCallWithLayoutConstraintsNoOperands + +ENTRY 
%CustomCallWithLayoutConstraints () -> f32[1,2,3] { + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={} +} + +)" +}, +// CustomCallWithLayoutConstraintsTupleShapes +{ +"CustomCallWithLayoutConstraintsTupleShapes", +R"(HloModule CustomCallWithLayoutConstraintsTupleShapes + +ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) { + %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}} +} + )" }, }); @@ -966,6 +1003,21 @@ ENTRY Sort { ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0} } +)" +}, +// Sort (Key, Value, Value, Value) +{ +"SortManyValues", +R"(HloModule sort + +ENTRY Sort { + keys = f32[1024,16]{0,1} parameter(0) + values.0 = s32[1024,16]{0,1} parameter(1) + values.1 = u32[1024,16]{0,1} parameter(2) + values.2 = f32[1024,16]{0,1} parameter(3) + ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0} +} + )" }, // Conditional @@ -1163,49 +1215,80 @@ ENTRY Sort { // clang-format on } -class HloParserTest : public ::testing::Test, - public ::testing::WithParamInterface { +// The test class for those tests defined above which round-trip through the +// parser and ToString is templatized on two bool parameters: +// +// short_form : used for the "short" test cases which use the ShortParsable +// output form. +// proto_round_trip : whether the module should also be round-tripped through +// HloProto form. This provides much better coverage for the proto +// serialization/deserialization. 
+// +// The proto_round_trip=true case also technically covers the Parser->ToString +// roundtrip as well, but separating out the Parser->ToString roundtrip as its +// own test provides better isolation and could conceivably catch weirdo bugs +// which are hidden by interaction between the textual and proto roundtripping. +template +class HloParameterizedParserTest + : public ::testing::Test, + public ::testing::WithParamInterface { protected: - static void ExpectHasSubstr(string_view s, string_view expected) { - EXPECT_TRUE(absl::StrContains(s, expected)) - << "'" << s << "' does not contain '" << expected << "'"; - } - // Expects "ToString(ParseHloString(string)) == string", that is, parses the // string, asserts that it succeeded, stringifies the parsed module, and // checks that the it equals the original string. void ExpectEqual() { const string& original = GetParam().module_string; - auto result = ParseHloString(original); - TF_ASSERT_OK(result.status()); - EXPECT_EQ(original, result.ValueOrDie()->ToString( - HloPrintOptions().set_print_large_constants(true))); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(original)); + if (proto_round_trip) { + TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto( + module->ToProto(), module->config())); + } + if (short_form) { + EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable())); + } else { + EXPECT_EQ( + original, + module->ToString(HloPrintOptions().set_print_large_constants(true))); + } } }; -class HloParserShortTest : public HloParserTest { - protected: - void ExpectEqualShort() { - const string& original = GetParam().module_string; - auto result = ParseHloString(original); - TF_ASSERT_OK(result.status()); - EXPECT_EQ(original, - result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); - } -}; - -TEST_P(HloParserTest, Run) { ExpectEqual(); } +// These using shenanigans are required because the TEST_P macro doesn't like +// template instantiations which contain 
commas. +using HloParserTestLong = HloParameterizedParserTest; +using HloParserTestLongProto = HloParameterizedParserTest; +using HloParserTestShort = HloParameterizedParserTest; +using HloParserTestShortProto = HloParameterizedParserTest; -TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); } +TEST_P(HloParserTestLong, Run) { ExpectEqual(); } +TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); } +TEST_P(HloParserTestShort, Run) { ExpectEqual(); } +TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); } -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestLong, ::testing::ValuesIn(CreateTestCases()), TestDataToString); - -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, + HloParserTestLongProto, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestShort, + ::testing::ValuesIn(CreateShortTestCases()), + TestDataToString); +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, + HloParserTestShortProto, ::testing::ValuesIn(CreateShortTestCases()), TestDataToString); +class HloParserTest : public ::testing::Test { + protected: + static void ExpectHasSubstr(string_view s, string_view expected) { + EXPECT_TRUE(absl::StrContains(s, expected)) + << "'" << s << "' does not contain '" << expected << "'"; + } +}; + TEST_F(HloParserTest, Empty) { const string original = ""; auto result = ParseHloString(original); @@ -1273,7 +1356,7 @@ TEST_F(HloParserTest, MoreConstants) { ENTRY %SelectScalarS32True.v4 () -> s32[] { %constant.2 = pred[] constant(true) - %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4} + %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4} %constant = s32[] constant(42) %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] 
%constant) } @@ -1732,6 +1815,25 @@ ENTRY entry { "was parsing 8:39: error: instruction does not exist: aparam"); } +TEST_F(HloParserTest, SameNameDiffComputations) { + const string original = R"(HloModule same_names: +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT result = f32[] add(p0, p1) +} + +ENTRY ReduceR3ToR2 { + p0 = f32[8,16,256]{2,1,0} parameter(0) + p1 = f32[] constant(0) + ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original)); + ASSERT_NE(module->entry_computation(), nullptr); + EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); +} + TEST_F(HloParserTest, ParseSharding) { const string original = "{maximal device=42}"; TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); @@ -1785,27 +1887,142 @@ TEST(HloParserSingleOpTest, SingleOp) { const string text = "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, " "f32[2,4]{1,0} %x)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Parameter(0), op::Parameter(1))); } -TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) { +TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) { + const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)"; + StatusOr> module = ParseHloString(text); + ASSERT_TRUE(!module.status().ok()); + LOG(INFO) << "Status: " << module.status(); + EXPECT_THAT(module.status().ToString(), + ::testing::HasSubstr("expects '=' in instruction")); +} + +TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) { const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)"; - StatusOr> module = ParseHloOpToModule(text); + StatusOr> module = 
ParseHloString(text); ASSERT_TRUE(!module.status().ok()); LOG(INFO) << "Status: " << module.status(); - EXPECT_THAT( - module.status().ToString(), - ::testing::HasSubstr("Operand broadcast had no shape in HLO text")); + EXPECT_THAT(module.status().ToString(), + ::testing::HasSubstr("Operand had no shape in HLO text")); +} + +TEST(HloParserSingleOpTest, SingleOpNoNames) { + const string text = + "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Parameter(0), op::Parameter(1))); +} + +TEST(HloParserSingleOpTest, CanonicalOp) { + const string text = "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Parameter(0), op::Parameter(1))); + EXPECT_EQ( + computation->root_instruction()->ToString(HloPrintOptions::Canonical()), + text); +} + +TEST(HloParserSingleOpTest, CanonicalOpWithNested) { + const string text = + R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +}, body= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} 
tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_EQ( + computation->root_instruction()->ToString(HloPrintOptions::Canonical()), + text); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested) { + const string text = + R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls= +{ + %param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0) + %param_1 = f32[2]{0} parameter(1) + %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1} + ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Fusion(op::Parameter(0), op::Parameter(1))); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + result = f32[] add(f32[] x, f32[] y) +})"; + auto status = ParseHloString(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("does not exist: x")); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + f32[] add(f32[] x, f32[] y) +})"; + auto status = ParseHloString(text).status(); + ASSERT_FALSE(status.ok()); + 
EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name")); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + result = f32[] add(f32[], f32[]) +})"; + auto status = ParseHloString(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name")); } TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { const string text = R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), @@ -1904,5 +2121,47 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { op::Broadcast(), op::Multiply(), op::Add())); } +TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) { + const string original = R"(HloModule CustomCallWrongNumberofOperandConstraints + +ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}} +} + +)"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Expected 2 operand layout constraints, 1 given"); +} + +TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) { + const string original = R"(HloModule CustomCallIncompatibleOperandConstraints + +ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} 
parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}} +} + +)"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "operand 1 is not compatible with operand shape"); +} + +TEST_F(HloParserTest, AllowShapeWhitespace) { + const string text = R"( +HloModule module + +ENTRY entry { + ROOT root = f32[ 1, 2,3, 4, 5]{0, 1, 2,3, 4 } parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); +} + +// custom call incompatible shape. + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 8c2f928ca101fae8e63663705554ae626c863bf6..5e004ce78ac1fd6da18ab2a54d23ef27e9586cf6 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" @@ -24,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -74,8 +75,8 @@ StatusOr HloPassPipeline::RunPassesInternal( std::vector HloPassPipeline::GetEnabledPasses( const DebugOptions& debug_options) { auto repeated_field = debug_options.xla_disable_hlo_passes(); - tensorflow::gtl::FlatSet disabled_pass_names(repeated_field.begin(), - repeated_field.end()); + absl::flat_hash_set disabled_pass_names(repeated_field.begin(), + repeated_field.end()); if (!disabled_pass_names.empty()) { VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " << absl::StrJoin(disabled_pass_names, ", "); @@ -98,7 +99,7 @@ void HloPassPipeline::MaybeDumpHlo(const HloModule& module, if (!proto_dump_path.empty()) { static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); static auto* const module_id_to_pass_number = - new tensorflow::gtl::FlatMap(); + new absl::flat_hash_map(); tensorflow::mutex_lock lock(mu); const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index b9c0b0c4ee1957fce48641230cef6391bcc9180e..cf33668f5bfa64a7843efc76e9f6768d18533240 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include @@ -36,17 +37,28 @@ HloProto MakeHloProto(const HloModule& module) { return proto; } +StatusOr> CreateModuleFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config) { + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(proto, module_config)); + TF_RETURN_IF_ERROR( + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status()); + return std::move(module); +} + StatusOr> EntryComputationParameterShapes( const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { return NotFound("HloProto missing HloModuleProto."); } - if (!hlo_proto.hlo_module().has_program_shape()) { + if (!hlo_proto.hlo_module().has_host_program_shape()) { return NotFound("HloProto missing program shape."); } std::vector parameter_shapes; - const auto& program_shape = hlo_proto.hlo_module().program_shape(); + const auto& program_shape = hlo_proto.hlo_module().host_program_shape(); for (const Shape& shape : program_shape.parameters()) { parameter_shapes.push_back(&shape); } @@ -57,14 +69,14 @@ StatusOr EntryComputationOutputShape(const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { return NotFound("HloProto missing HloModuleProto."); } - if (!hlo_proto.hlo_module().has_program_shape()) { + if (!hlo_proto.hlo_module().has_host_program_shape()) { return NotFound("HloProto missing program shape."); } - if (!hlo_proto.hlo_module().program_shape().has_result()) { + if (!hlo_proto.hlo_module().host_program_shape().has_result()) { return NotFound("HloProto missing result in its program shape"); } - return &hlo_proto.hlo_module().program_shape().result(); + return &hlo_proto.hlo_module().host_program_shape().result(); } } // namespace xla diff --git 
a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h index 3d9c375cd5d26f92cf8316f78789daf4fc08c927..1db82dd6fcaa5d7fe7d65894c1021105f0b26266 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.h +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -35,6 +35,12 @@ HloProto MakeHloProto(const HloModule& module, // will not be included in the output. HloProto MakeHloProto(const HloModule& module); +// Create an HLO state from serialized representation. In addition to +// creating the proto with HloModule::CreateFromProto(...) it also +// uses HloVerifier to ensure basic invariants are held. +StatusOr> CreateModuleFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config); + // Returns the shapes of the parameters of the entry computation. Shape pointers // refer to shapes inside of the given HloProto. StatusOr> EntryComputationParameterShapes( diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index 2a07b6fcbc243d955e136ccdf097c8155a115845..2d5197be9e6f69f698729e06b7506a5bc6260bcd 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -24,7 +24,7 @@ namespace hlo_query { bool IsConstantR0F32(HloInstruction* instruction, float* out) { if (instruction->opcode() == HloOpcode::kConstant && - ShapeUtil::IsScalarF32(instruction->shape())) { + ShapeUtil::IsScalarWithElementType(instruction->shape(), F32)) { *out = instruction->literal().Get({}); return true; } diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index b66a2aa4bd2b00a88cdbfa6b41c9123bb370aa87..5a5f01f8fd647c74217c80ce4a7633b8957e335f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -19,11 +19,11 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -154,7 +154,7 @@ class HloReachabilityMap { // Dense assignment from HloInstruction* to number. These numbers index // into the bit_vectors_ vector and into the bits within a BitVector. - tensorflow::gtl::FlatMap indices_; + absl::flat_hash_map indices_; // Bitvectors holding the reachability to each instruction. The bit vector for // instruction X includes ones for each instruction which X is reachable from. diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index a43867193628d05ad7703a5d5ed8bdc9c72de581..49e46ecd00ee4370f3e93746348373b79febed3d 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -75,7 +77,7 @@ bool IsRematerializable(const HloInstruction* instruction) { // cache before, and eventually calling the IsRematerializable() API. bool CanBeRematerialized( const HloInstruction* instruction, - tensorflow::gtl::FlatMap* remat_able) { + absl::flat_hash_map* remat_able) { auto it = remat_able->find(instruction); if (it != remat_able->end()) { return it->second; @@ -268,7 +270,7 @@ class InstructionList { Item* first_; // Item for each instruction. - tensorflow::gtl::FlatMap item_map_; + absl::flat_hash_map item_map_; }; // Return the items which use the given LogicalBuffer. 
Sets @@ -503,7 +505,7 @@ MemoryUsageTracker::MemoryUsageTracker( PointsToSet::BufferSet live_out_set = points_to_analysis.GetPointsToSet(computation_->root_instruction()) .CreateFlattenedSet(); - tensorflow::gtl::FlatMap + absl::flat_hash_map logical_buffer_to_buffer_id; for (auto* item = instruction_list_.first(); item != nullptr; @@ -854,7 +856,7 @@ int64 RematerializationCost(const HloInstruction* instruction, Item* PickRematerializationCandidate( const MemoryUsageTracker& memory_tracker, const InstructionList& instruction_list, int64 memory_limit_bytes, - tensorflow::gtl::FlatMap* remat_able) { + absl::flat_hash_map* remat_able) { Item* best_item = nullptr; int64 best_cost = 0; @@ -980,10 +982,10 @@ StatusOr HloRematerialization::RematerializeComputation( // rematerialization is essentially a move). If the next rematerialization of // the instruction is also a move then the rematerialization is added to the // blacklist. - tensorflow::gtl::FlatSet remat_move_instructions; + absl::flat_hash_set remat_move_instructions; // The map from instructions to their rematerializable status. - tensorflow::gtl::FlatMap remat_able; + absl::flat_hash_map remat_able; // The peak memory of the computation at any point in the instruction // sequence. @@ -1213,7 +1215,7 @@ StatusOr HloRematerialization::Run(HloModule* module) { // by the caller. 
int64 module_output_size = 0; ShapeUtil::ForEachSubshape( - module->entry_computation()->root_instruction()->shape(), + module->result_shape(), [&module_output_size, this](const Shape& subshape, const ShapeIndex& /*index*/) { module_output_size += size_function_(subshape); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 7330d73c09eb5aa8265fa5753a2de5885f51bf15..70d83c04f07ca7fd0139f586869e8fe688f958f4 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -15,6 +15,8 @@ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_ +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -115,14 +117,13 @@ class HloRematerialization : public HloModulePass { // computations called from sequential context // (CallContext::kSequential). These values are updated as rematerialization // occurs. - tensorflow::gtl::FlatMap - computation_peak_memory_; + absl::flat_hash_map computation_peak_memory_; std::unique_ptr points_to_analysis_; // Set of computations which have had rematerialization // applied. Rematerialization is only applied once per computation. - tensorflow::gtl::FlatSet rematerialized_computations_; + absl::flat_hash_set rematerialized_computations_; // Count of the total instructions rematerialized. 
int64 instructions_rematerialized_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index fa7f216321988137dcf9104a324f5f7789869aa5..3f0ca342b4c84216ddd5ee553848360d8bd1ff0b 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -324,10 +325,13 @@ StatusOr> HloRunner::ExecuteReplicated( StatusOr> HloRunner::CreateExecutable( std::unique_ptr module, bool run_hlo_passes) { if (run_hlo_passes) { + auto module_group = absl::make_unique(std::move(module)); TF_ASSIGN_OR_RETURN( - module, backend().compiler()->RunHloPasses( - std::move(module), backend().default_stream_executor(), - backend().memory_allocator())); + auto executables, + backend().compiler()->Compile(std::move(module_group), + {{backend().default_stream_executor()}}, + backend().memory_allocator())); + return std::move(executables[0]); } return backend().compiler()->RunBackend(std::move(module), backend().default_stream_executor(), diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index 3fc5dbeb02a26134a7f255fa0b6ebda1dc41ce4d..0778ff52174ef89c476950f2c830268a63888382 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" @@ -30,7 +32,7 @@ namespace xla { /* static */ StatusOr HloSchedule::CreateFromProto( const HloModule* module, const HloScheduleProto& proto) { - tensorflow::gtl::FlatMap id_to_computation; + absl::flat_hash_map id_to_computation; for (const HloComputation* computation : module->computations()) { id_to_computation[computation->unique_id()] = computation; } @@ -44,7 +46,7 @@ namespace xla { << "No computation exists in HLO module with id " << computation_id; const HloComputation* computation = comp_it->second; - tensorflow::gtl::FlatMap id_to_instruction; + absl::flat_hash_map id_to_instruction; for (const HloInstruction* instruction : computation->instructions()) { id_to_instruction[instruction->unique_id()] = instruction; } @@ -112,13 +114,13 @@ Status HloSchedule::UpdateComputationSchedule( const HloComputation* computation) { // Map from unique ID to HloInstruction pointer for instructions in the // computation. - tensorflow::gtl::FlatMap id_to_instruction; + absl::flat_hash_map id_to_instruction; for (const HloInstruction* instruction : computation->instructions()) { InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); } // Set of all HloInstructions in the schedule. - tensorflow::gtl::FlatSet ids_in_schedule; + absl::flat_hash_set ids_in_schedule; for (int id : sequences_.at(computation->unique_id()).ids()) { InsertOrDie(&ids_in_schedule, id); } @@ -126,15 +128,13 @@ Status HloSchedule::UpdateComputationSchedule( // Map from HloInstruction X to newly added instructions (instruction is in // computation, but not in schedule) which use X. If an instruction is not in // the map, then it has no users which are newly added instructions. 
- tensorflow::gtl::FlatMap> + absl::flat_hash_map> new_instruction_uses; // For each newly added instruction, this is the count of the instruction's // operands that have not yet been scheduled. When this value reaches zero, // then the instruction may be placed in the schedule. - tensorflow::gtl::FlatMap - unscheduled_operand_count; + absl::flat_hash_map unscheduled_operand_count; // Create a worklist of newly added instructions which are ready to be added // to the schedule. Initialize worklist with those that have zero operands. @@ -211,15 +211,15 @@ Status HloSchedule::Update() { if (sequences_.size() > nonfusion_computations.size()) { // Schedule contains some computations which have been removed from the // HloModule. Remove them from the schedule as well. - tensorflow::gtl::FlatSet nonfusion_computations_ids; + absl::flat_hash_set nonfusion_computations_ids; for (const HloComputation* computation : nonfusion_computations) { nonfusion_computations_ids.insert(computation->unique_id()); } for (auto it = sequences_.begin(); it != sequences_.end();) { if (nonfusion_computations_ids.count(it->first) == 0) { - it = sequences_.erase(it); + sequences_.erase(it++); } else { - it++; + ++it; } } } @@ -235,7 +235,6 @@ Status HloSchedule::Update() { Status HloSchedule::Verify() const { VLOG(2) << "VerifySchedule()"; - XLA_VLOG_LINES(3, module_->ToString()); XLA_VLOG_LINES(2, ToString()); // Verify schedule contains exactly the same set of non-fusion computations as @@ -254,7 +253,7 @@ Status HloSchedule::Verify() const { // For each computation verify the set of instructions is the same and that // each dependency and control edge is honored. 
for (const HloComputation* computation : nonfusion_computations) { - tensorflow::gtl::FlatMap instruction_position; + absl::flat_hash_map instruction_position; int pos = 0; for (const HloInstruction* instruction : sequence(computation).instructions()) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 270fe6039f0afd119c76086de9a0596e0560e93e..0a714101ee587aa847fa674bbde5586287c51f33 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -103,8 +104,7 @@ class HloSchedule { // Returns a map from HloComputation unique ID to instruction sequence. The // map contains all sequences in the schedule. - const tensorflow::gtl::FlatMap& sequences() - const { + const absl::flat_hash_map& sequences() const { return sequences_; } @@ -148,7 +148,7 @@ class HloSchedule { // A map from computation unique ID to instruction sequence. Unique IDs are // used rather than HloComputation pointers because HLO pointers are not // unique across HLO transformations because pointers may be recycled. - tensorflow::gtl::FlatMap sequences_; + absl::flat_hash_map sequences_; }; std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index de7e6b53d4d2aa88e2213248370b4da82bdeadeb..70a860c356ca2fb1c4c973ea3d96c50fabc2c7c2 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/overflow_util.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -369,10 +370,28 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return HloSharding(tuple_shardings); } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { return Replicate(); - } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL || - proto.tile_assignment_devices().size() == 1) { + } else if (proto.tile_assignment_devices().size() == 1) { return HloSharding(proto.tile_assignment_devices(0)); } + + TF_RET_CHECK(proto.type() != OpSharding::Type::OpSharding_Type_MAXIMAL) + << "Maximal sharding is expected to have single device assignment, but " + << proto.tile_assignment_devices().size() << " has provided."; + + TF_RET_CHECK(proto.tile_assignment_devices().size() > 1); + TF_RET_CHECK(!proto.tile_assignment_dimensions().empty()); + + // RE: the product of tile assignment tensor dimensions must be + // equal to tile_assignment_devices.size(). + int64 product_of_dimensions = 1; + for (auto dimension : proto.tile_assignment_dimensions()) { + TF_RET_CHECK(dimension > 0); + product_of_dimensions = + MultiplyWithoutOverflow(product_of_dimensions, dimension); + TF_RET_CHECK(product_of_dimensions > 0); + } + TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size()); + // Some versions of gcc cannot infer the TileAssignment constructor from a // braced initializer-list, so create one manually. 
std::vector devices(proto.tile_assignment_devices().begin(), @@ -450,6 +469,9 @@ absl::optional HloSharding::ExtractSingleSharding() const { if (!IsTuple()) { return *this; } + if (tuple_elements_.empty()) { + return absl::nullopt; + } for (int64 i = 1; i < tuple_elements_.size(); ++i) { if (tuple_elements_[0] != tuple_elements_[i]) { return absl::nullopt; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index e3f4a9852ace86c20610362aa6ad3c3d9c78de30..88329c899794a6e0f5102d181d6161fe17f89932 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -169,14 +169,14 @@ Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, // If user is a tuple instruction, return the tuple subsharding corresponding to // the operand matching the instruction argument, because that is the // subsharding corresponding to instruction. -ShapeTree GetShardingTreeFromUser( +StatusOr> GetShardingTreeFromUser( const HloInstruction& instruction, const HloInstruction& user) { if (user.opcode() == HloOpcode::kTuple) { return user.sharding() .GetSubSharding(user.shape(), {user.operand_index(&instruction)}) - .GetAsShapeTree(instruction.shape()); + .AsShapeTree(instruction.shape()); } - return user.sharding().GetAsShapeTree(user.shape()); + return user.sharding().AsShapeTree(user.shape()); } // Assign rhs to lhs. 
If rhs is unassigned (assigned to kUnassignedDevice) @@ -264,8 +264,8 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, continue; } AssignmentKind sub_assigned = AssignmentKind::kUnassigned; - ShapeTree user_sharding_tree = - GetShardingTreeFromUser(*instruction, *user); + TF_ASSIGN_OR_RETURN(ShapeTree user_sharding_tree, + GetShardingTreeFromUser(*instruction, *user)); if (ShapeUtil::IsTuple(instruction->shape())) { // For tuple-shaped instructions collect individual tuple subshardings // from the uses, and then combine them into the tuple sharding. diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 85494877023fa3812973e993a349a7559706ab5d..59594ab2f0f70a206c73e998dbfa69c2c5c7ba43 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -167,7 +167,7 @@ void HloValue::SetPositionsAndComputeUses( positions_.insert(positions_.end(), positions.begin(), positions.end()); // Gather the computation roots at which this value appears. 
- tensorflow::gtl::FlatSet root_positions; + absl::flat_hash_set root_positions; for (const HloPosition& position : positions_) { if (position.instruction == position.instruction->parent()->root_instruction()) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 6eb66589048c1a8d6ccfed73c0f7e32f5fe6e568..a1cb60a0499847080c4c151c2828fa08e76c29d2 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" @@ -23,10 +24,26 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { +Status ShapeVerifier::Preprocess(HloInstruction* hlo) { + if (LayoutUtil::IsSparseArray(hlo->shape())) { + return InternalError("Sparse arrays are not yet fully supported: %s", + hlo->ToString()); + } + return Status::OK(); +} + +static Status CheckOperandCount(const HloInstruction* hlo, int expected) { + if (hlo->operand_count() != expected) { + return InternalError("Expected %d operands for %s instruction: %s", + expected, HloOpcodeString(hlo->opcode()), + hlo->ToString()); + } + return Status::OK(); +} + Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) { return CheckUnaryShape(hlo); } @@ -58,12 +75,14 @@ Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) { } Status ShapeVerifier::HandleConvert(HloInstruction* convert) { + TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1)); return CheckShape(convert, ShapeInference::InferConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); } Status 
ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) { + TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1)); return CheckShape(convert, ShapeInference::InferBitcastConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); @@ -74,6 +93,7 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) { } Status ShapeVerifier::HandleDot(HloInstruction* dot) { + TF_RETURN_IF_ERROR(CheckOperandCount(dot, 2)); TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferDotOpShape( dot->operand(0)->shape(), dot->operand(1)->shape(), @@ -82,6 +102,7 @@ Status ShapeVerifier::HandleDot(HloInstruction* dot) { } Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { + TF_RETURN_IF_ERROR(CheckOperandCount(convolution, 2)); TF_ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferConvolveShape( @@ -92,6 +113,7 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { } Status ShapeVerifier::HandleFft(HloInstruction* fft) { + TF_RETURN_IF_ERROR(CheckOperandCount(fft, 1)); TF_ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(), @@ -118,11 +140,13 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { } Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1)); return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( hlo->operand(0)->shape())); } Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { + TF_RETURN_IF_ERROR(CheckOperandCount(reduce_precision, 1)); return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), reduce_precision->exponent_bits(), @@ -156,6 +180,7 @@ Status ShapeVerifier::CheckOperandAndParameter( } Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1)); HloInfeedInstruction* infeed = Cast(instruction); 
TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); @@ -166,6 +191,7 @@ Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { } Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); HloOutfeedInstruction* outfeed = Cast(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); @@ -192,10 +218,7 @@ bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0, } Status ShapeVerifier::HandleRng(HloInstruction* instruction) { - if (instruction->operand_count() != 2) { - return InternalError("Expected two operands for Rng instruction: %s", - instruction->ToString()); - } + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); const Shape& shape_0 = instruction->operand(0)->shape(); const Shape& shape_1 = instruction->operand(1)->shape(); @@ -244,29 +267,42 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { } Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { + TF_RETURN_IF_ERROR(CheckOperandCount(reverse, 1)); return CheckShape( reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(), reverse->dimensions())); } Status ShapeVerifier::HandleSort(HloInstruction* sort) { - if (sort->operand_count() == 2 && - !ShapeUtil::SameDimensions(sort->operand(0)->shape(), - sort->operand(1)->shape())) { - return InternalError( - "Expected sort to have to have the same dimensions for the keys and " - "the values. 
Keys shape is: %s\n, Values shape is: %s", - StringifyShape(sort->operand(0)->shape()), - StringifyShape(sort->operand(1)->shape())); + if (sort->operand_count() < 1) { + return InternalError("Expected at least 1 operand for %s instruction: %s", + HloOpcodeString(sort->opcode()), sort->ToString()); + } + for (int64 operand = 1; operand < sort->operand_count(); ++operand) { + if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(), + sort->operand(operand)->shape())) { + return InternalError( + "Expected sort to have to have the same dimensions for the keys " + "and the values. Keys shape is: %s\n, Values shape (operand index " + "%lld) is: %s", + StringifyShape(sort->operand(0)->shape()), operand, + StringifyShape(sort->operand(operand)->shape())); + } } return CheckVariadicShape(sort); } Status ShapeVerifier::HandleConstant(HloInstruction* constant) { + TF_RETURN_IF_ERROR(CheckOperandCount(constant, 0)); + if (!Cast(constant)->HasLiteral()) { + return InternalError("Constant is required to have a valid literal: %s", + constant->ToString()); + } return CheckShape(constant, constant->literal().shape()); } Status ShapeVerifier::HandleIota(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0)); auto* iota = Cast(instruction); const int64 rank = ShapeUtil::Rank(iota->shape()); if (rank == 0) { @@ -281,6 +317,7 @@ Status ShapeVerifier::HandleIota(HloInstruction* instruction) { } Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { + TF_RETURN_IF_ERROR(CheckOperandCount(get_tuple_element, 1)); return CheckShape(get_tuple_element, ShapeInference::InferGetTupleElementShape( get_tuple_element->operand(0)->shape(), @@ -288,6 +325,12 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { } Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { + if (reduce->operand_count() % 2 != 0) { + return InternalError( + "Expected an even number of operands for %s instruction: %s", + 
HloOpcodeString(reduce->opcode()), reduce->ToString()); + } + std::vector operand_shapes; for (const HloInstruction* operand : reduce->operands()) { operand_shapes.push_back(&operand->shape()); @@ -298,10 +341,12 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { + TF_RETURN_IF_ERROR(CheckOperandCount(bitcast, 1)); return Status::OK(); } Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { + TF_RETURN_IF_ERROR(CheckOperandCount(broadcast, 1)); // HLO broadcast has no exact analog at the proto level so there is no // ShapeInference method. Check the output shape explicitly. const Shape& operand_shape = broadcast->operand(0)->shape(); @@ -313,14 +358,16 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { operand_dimension < ShapeUtil::Rank(operand_shape); ++operand_dimension) { int64 output_dimension = broadcast->dimensions()[operand_dimension]; - TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) == - operand_shape.dimensions(operand_dimension)) + TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) && + (broadcast->shape().dimensions(output_dimension) == + operand_shape.dimensions(operand_dimension))) << broadcast->ToString() << " operand shape " << operand_shape; } return Status::OK(); } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { + TF_RETURN_IF_ERROR(CheckOperandCount(reshape, 1)); // Check for mixed precision. 
TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == @@ -329,17 +376,27 @@ Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { } Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { + TF_RETURN_IF_ERROR(CheckOperandCount(transpose, 1)); return CheckShape( transpose, ShapeInference::InferTransposeShape( transpose->operand(0)->shape(), transpose->dimensions())); } Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 0)); return Status::OK(); } Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { - for (HloInstruction* fused_param : fusion->fused_parameters()) { + auto& fused_parameters = fusion->fused_parameters(); + if (fused_parameters.size() != fusion->operand_count()) { + return InternalError( + "Fused parameter count (%d) does not match the number of operands (%d)" + " passed to the fusion instruction in: %s.", + fused_parameters.size(), fusion->operand_count(), + fusion->ToString().c_str()); + } + for (HloInstruction* fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) { return InternalError( @@ -359,9 +416,30 @@ Status ShapeVerifier::HandleCall(HloInstruction* call) { return CheckShape(call, call->to_apply()->root_instruction()->shape()); } -Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); } +Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + TF_RET_CHECK(custom_call != nullptr); + if (custom_call->layout_constrained()) { + // If the layout is constrained, verify all the respective shapes have + // layouts and that the constrained operand shapes match the shapes of the + // operands. 
+ TF_RET_CHECK(LayoutUtil::HasLayout(custom_call->shape())); + TF_RET_CHECK(custom_call->operand_count() == + custom_call->operand_shapes_with_layout().size()); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + const Shape& operand_shape_with_layout = + custom_call->operand_shapes_with_layout()[i]; + TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(), + operand_shape_with_layout)); + TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout)); + } + } + return Status::OK(); +} Status ShapeVerifier::HandleSlice(HloInstruction* slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(slice, 1)); return CheckShape(slice, ShapeInference::InferSliceShape( slice->operand(0)->shape(), slice->slice_starts(), @@ -369,6 +447,7 @@ Status ShapeVerifier::HandleSlice(HloInstruction* slice) { } Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_slice, 2)); return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( dynamic_slice->operand(0)->shape(), dynamic_slice->operand(1)->shape(), @@ -377,6 +456,7 @@ Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { Status ShapeVerifier::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_update_slice, 3)); return CheckShape(dynamic_update_slice, ShapeInference::InferDynamicUpdateSliceShape( dynamic_update_slice->operand(0)->shape(), @@ -406,6 +486,7 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) { } Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { + TF_RETURN_IF_ERROR(CheckOperandCount(reduce_window, 2)); return CheckShape( reduce_window, ShapeInference::InferReduceWindowShape( @@ -415,6 +496,7 @@ Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { } Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3)); return 
CheckShape( instruction, ShapeInference::InferSelectAndScatterShape( @@ -425,6 +507,7 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { } Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { + TF_RETURN_IF_ERROR(CheckOperandCount(xla_while, 1)); TF_RETURN_IF_ERROR( CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0)); TF_RETURN_IF_ERROR( @@ -444,6 +527,7 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { } Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { + TF_RETURN_IF_ERROR(CheckOperandCount(conditional, 3)); TF_RETURN_IF_ERROR(CheckOperandAndParameter( conditional, 1, conditional->true_computation(), 0)); TF_RETURN_IF_ERROR(CheckOperandAndParameter( @@ -458,12 +542,14 @@ Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { } Status ShapeVerifier::HandlePad(HloInstruction* pad) { + TF_RETURN_IF_ERROR(CheckOperandCount(pad, 2)); return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(), pad->operand(1)->shape(), pad->padding_config())); } Status ShapeVerifier::HandleSend(HloInstruction* send) { + TF_RETURN_IF_ERROR(CheckOperandCount(send, 2)); return CheckShape(send, ShapeUtil::MakeTupleShape({send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {}), @@ -471,10 +557,12 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) { } Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { + TF_RETURN_IF_ERROR(CheckOperandCount(send_done, 1)); return CheckShape(send_done, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleRecv(HloInstruction* recv) { + TF_RETURN_IF_ERROR(CheckOperandCount(recv, 1)); return CheckShape( recv, ShapeUtil::MakeTupleShape( {ShapeUtil::GetTupleElementShape(recv->shape(), 0), @@ -482,6 +570,7 @@ Status ShapeVerifier::HandleRecv(HloInstruction* recv) { } Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { + TF_RETURN_IF_ERROR(CheckOperandCount(recv_done, 1)); return CheckShape( 
recv_done, ShapeUtil::MakeTupleShape( @@ -491,6 +580,7 @@ Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { Status ShapeVerifier::HandleBatchNormTraining( HloInstruction* batch_norm_training) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_training, 3)); return CheckShape(batch_norm_training, ShapeInference::InferBatchNormTrainingShape( batch_norm_training->operand(0)->shape(), @@ -501,6 +591,7 @@ Status ShapeVerifier::HandleBatchNormTraining( Status ShapeVerifier::HandleBatchNormInference( HloInstruction* batch_norm_inference) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_inference, 5)); return CheckShape(batch_norm_inference, ShapeInference::InferBatchNormInferenceShape( batch_norm_inference->operand(0)->shape(), @@ -512,6 +603,7 @@ Status ShapeVerifier::HandleBatchNormInference( } Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_grad, 5)); return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape( batch_norm_grad->operand(0)->shape(), batch_norm_grad->operand(1)->shape(), @@ -548,6 +640,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kTupleSelect: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kSort: case HloOpcode::kTuple: case HloOpcode::kWhile: break; @@ -579,6 +672,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { } // namespace Status ShapeVerifier::HandleGather(HloInstruction* gather) { + TF_RETURN_IF_ERROR(CheckOperandCount(gather, 2)); return CheckShape( gather, ShapeInference::InferGatherShape( @@ -587,6 +681,7 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { } Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { + TF_RETURN_IF_ERROR(CheckOperandCount(scatter, 3)); return CheckShape( scatter, ShapeInference::InferScatterShape( scatter->operand(0)->shape(), scatter->operand(1)->shape(), @@ -674,12 +769,14 @@ Status 
ShapeVerifier::CheckShape(const HloInstruction* instruction, } Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1)); return CheckShape(instruction, ShapeInference::InferUnaryOpShape(instruction->opcode(), instruction->operand(0))); } Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); return CheckShape( instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(), instruction->operand(0), @@ -687,6 +784,7 @@ Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { } Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3)); return CheckShape(instruction, ShapeInference::InferTernaryOpShape( instruction->opcode(), instruction->operand(0), @@ -763,7 +861,191 @@ Status VerifyHloStructure(HloModule* module) { return Status::OK(); } -Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { +namespace { + +// Returns true if the given Shape has a TOKEN shape as any subshape. +bool ShapeContainsToken(const Shape& shape) { + bool contains_token = false; + ShapeUtil::ForEachSubshape( + shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsToken(subshape)) { + contains_token = true; + } + }); + return contains_token; +} + +// Verifies that all types entering and exiting the entry computation are +// legal. +Status VerifyEntryAndExitShapes(const HloModule& module) { + // Tokens cannot be passed as entry parameters. + // TODO(b/80000000): Remove this constraint. 
+ for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { + HloInstruction* param = + module.entry_computation()->parameter_instruction(i); + if (ShapeContainsToken(param->shape())) { + return InternalError( + "Entry parameter %d is or contains a token shape: %s", i, + ShapeUtil::HumanString(param->shape())); + } + } + return Status::OK(); +} + +// Verifies that entry computation layout matches characteristics of +// entry computation. +Status CheckEntryComputationLayout(const HloModule& module) { + const HloComputation* computation = module.entry_computation(); + const auto& layout = module.entry_computation_layout(); + const ShapeLayout& result_layout = layout.result_layout(); + + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape())); + + if (LayoutUtil::IsSparseArray(result_layout.shape())) { + return Unimplemented( + "Sparse arrays are not yet fully supported in program result shape: %s", + ShapeUtil::HumanStringWithLayout(result_layout.shape())); + } + + if (!ShapeUtil::Compatible(computation->root_instruction()->shape(), + result_layout.shape())) { + return InternalError( + "Shape of the root instruction of entry computation (%s) should be " + "compatible to one specified in module's entry computation layout (%s)", + ShapeUtil::HumanString(computation->root_instruction()->shape()), + ShapeUtil::HumanString(result_layout.shape())); + } + + if (computation->num_parameters() != layout.parameter_count()) { + return InternalError( + "Number of parameters in entry computation layout (%d) must be same " + "as number of parameters of entry computation computation (%d)", + layout.parameter_count(), computation->num_parameters()); + } + + for (int i = 0; i < computation->num_parameters(); ++i) { + const HloInstruction* parameter = computation->parameter_instruction(i); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i))); + if 
(LayoutUtil::IsSparseArray(layout.parameter_shape(i))) { + return Unimplemented( + "Sparse arrays are not yet fully supported " + "in program parameter shape: %s", + ShapeUtil::HumanStringWithLayout(layout.parameter_shape(i))); + } + if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) { + return InternalError( + "Shape of the entry computation parameter %d is %s should be " + "compatible to the one specified in module's entry computation " + "layout %s", + i, ShapeUtil::HumanString(parameter->shape()), + ShapeUtil::HumanString(layout.parameter_shape(i))); + } + } + + return Status::OK(); +} + +// Checks if the given two instructions share the same channel id. +Status CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2) { + if (instr1->channel_id() != instr2->channel_id()) { + return InternalError( + "Expected to have the same channel id, actual channel ids are: %s " + "(%d), %s (%d)", + instr1->ToString(), instr1->channel_id(), instr2->ToString(), + instr2->channel_id()); + } + return Status::OK(); +} + +// Checks if the given two instructions have the same is_host_transfer +// attribute value. Intsructions must be send/recv instructions or their +// 'done' variant. +Status CheckSameIsHostTransfer(const HloInstruction* instr1, + const HloInstruction* instr2) { + const HloSendRecvInstruction* send_recv1 = + DynCast(instr1); + const HloSendRecvInstruction* send_recv2 = + DynCast(instr2); + TF_RET_CHECK(send_recv1 != nullptr); + TF_RET_CHECK(send_recv2 != nullptr); + if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) { + return InternalError( + "Expected instructions to have the same is-host-transfer property: " + "%s, " + "%s ", + instr1->ToString(), instr2->ToString()); + } + return Status::OK(); +} + +// Checks various invariants of send and recv instructions. 
+Status VerifySendsAndRecvs(const HloModule& module) { + absl::flat_hash_map host_channels; + // Host send/recv instructions must have their own unique channel. + auto check_unique_host_channel = [&](const HloInstruction* instruction) { + const HloSendRecvInstruction* sendrecv = + DynCast(instruction); + if (sendrecv->is_host_transfer()) { + auto it_inserted = + host_channels.insert({sendrecv->channel_id(), sendrecv}); + if (!it_inserted.second) { + return FailedPrecondition( + "Channel %d is used for multiple host send/recv instructions: " + "%s " + "and " + "%s", + sendrecv->channel_id(), sendrecv->ToString(), + it_inserted.first->second->ToString()); + } + } + + return Status::OK(); + }; + + // Send/Recv instruction must have a single user: the corresponding + // SendDone/RecvDone. with matching channel. + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + switch (instruction->opcode()) { + case HloOpcode::kSend: { + TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* send_done = instruction->users().front(); + TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done)); + break; + } + case HloOpcode::kRecv: { + TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* recv_done = instruction->users().front(); + TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done)); + break; + } + case HloOpcode::kSendDone: + TF_RET_CHECK(instruction->operands().size() == 1); + TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend); + break; + case 
HloOpcode::kRecvDone: + TF_RET_CHECK(instruction->operands().size() == 1); + TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv); + break; + default: + break; + } + } + } + return Status::OK(); +} + +// CHECKs various invariants of a fusion instruction. +Status CheckFusionInstruction(HloInstruction* fusion) { // The parent fusion instruction of the fusion computation must be 'fusion'. HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { @@ -866,50 +1148,32 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { } } + TF_RET_CHECK(fusion->called_computations() == + absl::Span( + {fusion->fused_instructions_computation()})) + << "Fusion HLO calls computations other than the " + "fused_instructions_computation: " + << fusion->ToString() << " fusion->fused_instructions_computation(): " + << fusion->fused_instructions_computation()->ToString() + << " fusion->called_computations(): " + << ComputationsToString(fusion->called_computations()); + + for (const auto& fused : fusion->fused_instructions()) { + TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation()) + << "Fused HLO was missing a parent: " << fused->ToString() + << " parent: " << fused->parent() + << " computation: " << fusion->parent(); + } + // TODO(b/65423525): We'd like to check that all operands are distinct. // This is currently disabled due to the invariant being violated by // multi-output fusion. 
return Status::OK(); } -Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { - auto* while_cond = instruction->while_condition(); - auto* while_body = instruction->while_body(); - if (while_cond->num_parameters() != 1) { - return FailedPrecondition( - "While condition must have exactly 1 parameter; had %d : %s", - while_cond->num_parameters(), while_cond->ToString()); - } - if (while_body->num_parameters() != 1) { - return FailedPrecondition( - "While body must have exactly 1 parameter; had %d : %s", - while_body->num_parameters(), while_body->ToString()); - } - if (instruction->operand_count() != 1) { - return FailedPrecondition( - "While loop must have exactly one operand; had %d : %s", - instruction->operand_count(), instruction->ToString()); - } - return Status::OK(); -} - -Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) { - if (instruction->true_computation()->num_parameters() != 1) { - return FailedPrecondition( - "True computation %s of %s must have 1 parameter insted of %d", - instruction->true_computation()->name(), instruction->ToString(), - instruction->true_computation()->num_parameters()); - } - if (instruction->false_computation()->num_parameters() != 1) { - return FailedPrecondition( - "False computation %s of %s must have 1 parameter insted of %d", - instruction->false_computation()->name(), instruction->ToString(), - instruction->false_computation()->num_parameters()); - } - return Status::OK(); -} - -Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { +// Checks that the non-scalar operand shapes are compatible to the output +// shape, i.e., that there are no implicit broadcasts of size-one dimensions. 
+Status CheckElementwiseInstruction(HloInstruction* instruction) { const Shape& out_shape = instruction->shape(); for (HloInstruction* operand : instruction->operands()) { const Shape& operand_shape = operand->shape(); @@ -926,202 +1190,167 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { return Status::OK(); } -namespace { +// Visitor which verifies various fields on the HLO instruction. This class does +// not check result shape as that is checked in the ShapeVerifier. +class InstructionVerifier : public DfsHloVisitorWithDefault { + public: + explicit InstructionVerifier(std::function + instruction_can_change_layout_func) + : instruction_can_change_layout_func_( + instruction_can_change_layout_func) {} -// Returns true if the given Shape has a TOKEN shape as any subshape. -bool ShapeContainsToken(const Shape& shape) { - bool contains_token = false; - ShapeUtil::ForEachSubshape( - shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsToken(subshape)) { - contains_token = true; - } - }); - return contains_token; -} + Status DefaultAction(HloInstruction*) override { return Status::OK(); } -// Verifies that all types entering and exiting the entry computation are -// legal. -Status VerifyEntryAndExitShapes(const HloModule& module) { - // Tokens cannot be passed as entry parameters. - // TODO(b/80000000): Remove this constraint. - for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { - HloInstruction* param = - module.entry_computation()->parameter_instruction(i); - if (ShapeContainsToken(param->shape())) { - return InternalError( - "Entry parameter %d is or contains a token shape: %s", i, - ShapeUtil::HumanString(param->shape())); - } + Status HandleFusion(HloInstruction* fusion) override { + return CheckFusionInstruction(fusion); } - return Status::OK(); -} -// Checks if the given two instructions share the same channel id. 
-Status CheckSameChannel(const HloInstruction* instr1, - const HloInstruction* instr2) { - if (instr1->channel_id() != instr2->channel_id()) { - return InternalError( - "Expected to have the same channel id, actual channel ids are: %s " - "(%d), %s (%d)", - instr1->ToString(), instr1->channel_id(), instr2->ToString(), - instr2->channel_id()); + Status HandleBroadcast(HloInstruction* broadcast) override { + // If you see this failure then someone has confused the difference + // between the HLO broadcast op, and the UserComputation broadcast + // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I + // or ComputationLowerer::Visit() + TF_RET_CHECK(broadcast->dimensions().size() == + ShapeUtil::Rank(broadcast->operand(0)->shape())) + << "Broadcast HLO (" << broadcast->ToShortString() + << ") has invalid number of dimensions: " + << broadcast->dimensions().size() + << " != " << ShapeUtil::Rank(broadcast->operand(0)->shape()); + return Status::OK(); } - return Status::OK(); -} -// Checks if the given two instructions have the same is_host_transfer -// attribute value. Intsructions must be send/recv instructions or their -// 'done' variant. 
-Status CheckSameIsHostTransfer(const HloInstruction* instr1, - const HloInstruction* instr2) { - const HloSendRecvInstruction* send_recv1 = - DynCast(instr1); - const HloSendRecvInstruction* send_recv2 = - DynCast(instr2); - TF_RET_CHECK(send_recv1 != nullptr); - TF_RET_CHECK(send_recv2 != nullptr); - if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) { - return InternalError( - "Expected instructions to have the same is-host-transfer property: " - "%s, " - "%s ", - instr1->ToString(), instr2->ToString()); + Status HandleWhile(HloInstruction* xla_while) override { + auto* while_cond = xla_while->while_condition(); + auto* while_body = xla_while->while_body(); + if (while_cond->num_parameters() != 1) { + return FailedPrecondition( + "While condition must have exactly 1 parameter; had %d : %s", + while_cond->num_parameters(), while_cond->ToString()); + } + if (while_body->num_parameters() != 1) { + return FailedPrecondition( + "While body must have exactly 1 parameter; had %d : %s", + while_body->num_parameters(), while_body->ToString()); + } + if (xla_while->operand_count() != 1) { + return FailedPrecondition( + "While loop must have exactly one operand; had %d : %s", + xla_while->operand_count(), xla_while->ToString()); + } + return Status::OK(); } - return Status::OK(); -} -// Checks various invariants of send and recv instructions. -Status VerifySendsAndRecvs(const HloModule& module) { - tensorflow::gtl::FlatMap host_channels; - // Host send/recv instructions must have their own unique channel. 
- auto check_unique_host_channel = [&](const HloInstruction* instruction) { - const HloSendRecvInstruction* sendrecv = - DynCast(instruction); - if (sendrecv->is_host_transfer()) { - auto it_inserted = - host_channels.insert({sendrecv->channel_id(), sendrecv}); - if (!it_inserted.second) { - return FailedPrecondition( - "Channel %d is used for multiple host send/recv instructions: " - "%s " - "and " - "%s", - sendrecv->channel_id(), sendrecv->ToString(), - it_inserted.first->second->ToString()); - } + Status HandleConditional(HloInstruction* conditional) override { + if (conditional->true_computation()->num_parameters() != 1) { + return FailedPrecondition( + "True computation %s of %s must have 1 parameter insted of %d", + conditional->true_computation()->name(), conditional->ToString(), + conditional->true_computation()->num_parameters()); } + if (conditional->false_computation()->num_parameters() != 1) { + return FailedPrecondition( + "False computation %s of %s must have 1 parameter insted of %d", + conditional->false_computation()->name(), conditional->ToString(), + conditional->false_computation()->num_parameters()); + } + return Status::OK(); + } + + Status HandleElementwiseUnary(HloInstruction* instruction) override { + return CheckElementwiseInstruction(instruction); + } + + Status HandleElementwiseBinary(HloInstruction* instruction) override { + return CheckElementwiseInstruction(instruction); + } + Status HandleGetTupleElement(HloInstruction* gte) override { + TF_RET_CHECK(ShapeUtil::IsTuple(gte->operand(0)->shape())); return Status::OK(); - }; + } - // Send/Recv instruction must have a single user: the corresponding - // SendDone/RecvDone. with matching channel. 
- for (const HloComputation* computation : module.computations()) { - for (const HloInstruction* instruction : computation->instructions()) { - switch (instruction->opcode()) { - case HloOpcode::kSend: { - TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); - TF_RET_CHECK(instruction->users().size() == 1); - const HloInstruction* send_done = instruction->users().front(); - TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done)); - TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done)); - break; - } - case HloOpcode::kRecv: { - TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); - TF_RET_CHECK(instruction->users().size() == 1); - const HloInstruction* recv_done = instruction->users().front(); - TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done)); - TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done)); - break; + Status HandleTranspose(HloInstruction* transpose) override { + const Shape& shape = transpose->shape(); + const HloInstruction* operand = transpose->operand(0); + TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size()); + TF_RET_CHECK(shape.dimensions().size() == + transpose->operand(0)->shape().dimensions().size()); + TF_RET_CHECK(std::equal( + operand->shape().dimensions().begin(), + operand->shape().dimensions().end(), + Permute(transpose->dimensions(), shape.dimensions()).begin())) + << "shape: " << shape << ", operand->shape(): " << operand->shape() + << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ") + << "}"; + return Status::OK(); + } + + Status Preprocess(HloInstruction* instruction) override { + auto previous = instructions_by_name_.find(instruction->name()); + TF_RET_CHECK(previous == instructions_by_name_.end()) + << "HLO has name that is not unique within module:\n" + << instruction->ToString() + << " in computation: " << 
instruction->parent()->name() + << "\nPrevious HLO with same name:\n" + << previous->second->ToString() + << " in computation: " << previous->second->parent()->name(); + instructions_by_name_[instruction->name()] = instruction; + return Status::OK(); + } + + Status Postprocess(HloInstruction* instruction) override { + if (instruction_can_change_layout_func_ && + LayoutUtil::IsDenseArray(instruction->shape()) && + !instruction_can_change_layout_func_(instruction)) { + const Shape& result_shape = instruction->shape(); + const Layout& result_layout = result_shape.layout(); + for (HloInstruction* operand : instruction->operands()) { + const Shape& operand_shape = operand->shape(); + if (LayoutUtil::IsDenseArray(operand_shape) && + ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) { + const Layout& operand_layout = operand_shape.layout(); + TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout)) + << "Instruction shouldn't change layouts " + << instruction->ToString() << " From " + << ShapeUtil::HumanStringWithLayout(operand_shape) << " To " + << ShapeUtil::HumanStringWithLayout(result_shape); } - case HloOpcode::kSendDone: - TF_RET_CHECK(instruction->operands().size() == 1); - TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend); - break; - case HloOpcode::kRecvDone: - TF_RET_CHECK(instruction->operands().size() == 1); - TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv); - break; - default: - break; } } + + return Status::OK(); } - return Status::OK(); -} + + private: + absl::flat_hash_map instructions_by_name_; + // Determines whether an instruction can change layouts. + 
+ std::function + instruction_can_change_layout_func_; +}; } // namespace StatusOr HloVerifier::Run(HloModule* module) { TF_RET_CHECK(!module->name().empty()); + + if (module->entry_computation()->IsFusionComputation()) { + return InvalidArgument( + "Module entry computation cannot be a fusion computation"); + } + TF_RETURN_IF_ERROR(VerifyHloStructure(module)); TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module)); - tensorflow::gtl::FlatMap instructions; - for (auto* computation : module->computations()) { - for (const auto& instruction : computation->instructions()) { - TF_RET_CHECK(instruction->parent() == computation); - if (instruction->opcode() == HloOpcode::kFusion) { - TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction)); - TF_RET_CHECK(instruction->called_computations() == - absl::Span( - {instruction->fused_instructions_computation()})) - << "Fusion HLO calls computations other than the " - "fused_instructions_computation: " - << instruction->ToString() - << " instruction->fused_instructions_computation(): " - << instruction->fused_instructions_computation()->ToString() - << " instruction->called_computations(): " - << ComputationsToString(instruction->called_computations()); - - for (const auto& fused : instruction->fused_instructions()) { - TF_RET_CHECK(fused->parent() == - instruction->fused_instructions_computation()) - << "Fused HLO was missing a parent: " << fused->ToString() - << " parent: " << fused->parent() - << " computation: " << computation; - } - } else if (instruction->opcode() == HloOpcode::kBroadcast) { - // If you see this failure then someone has confused the difference - // between the HLO broadcast op, and the UserComputation broadcast - // op. 
See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I - // or ComputationLowerer::Visit() - TF_RET_CHECK(instruction->dimensions().size() == - ShapeUtil::Rank(instruction->operand(0)->shape())) - << "Broadcast HLO (" << instruction->ToShortString() - << ") has invalid number of dimensions: " - << instruction->dimensions().size() - << " != " << ShapeUtil::Rank(instruction->operand(0)->shape()); - } else if (instruction->opcode() == HloOpcode::kWhile) { - TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction)); - } else if (instruction->opcode() == HloOpcode::kConditional) { - TF_RETURN_IF_ERROR(CheckConditionalInstruction(instruction)); - } else if (instruction->opcode() != - HloOpcode::kRng /* Rng operands are always scalar. */ - && instruction->IsElementwise()) { - TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction)); - } - - auto previous = instructions.find(instruction->name()); - TF_RET_CHECK(previous == instructions.end()) - << "HLO has name that is not unique within module:\n" - << instruction->ToString() - << " in computation: " << computation->name() - << "\nPrevious HLO with same name:\n" - << previous->second->ToString() - << " in computation: " << previous->second->parent()->name(); - instructions[instruction->name()] = instruction; - } - std::unique_ptr shape_verifier = shape_verifier_factory_(); TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); + + InstructionVerifier instruction_verifier( + instruction_can_change_layout_func_); + TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier)); } + TF_RETURN_IF_ERROR(CheckEntryComputationLayout(*module)); TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); // If the module has a schedule, it must be valid. 
@@ -1129,6 +1358,8 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(module->schedule().Verify()); } + TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(*module)); + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 0cde4a31af72e81829723c564f59edc362f73335..e1f3402465746b0478d7bb7e4ee2b66e3f839eb2 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -32,6 +32,8 @@ class ShapeVerifier : public DfsHloVisitor { : layout_sensitive_(layout_sensitive), allow_mixed_precision_(allow_mixed_precision) {} + Status Preprocess(HloInstruction* hlo) override; + Status HandleElementwiseUnary(HloInstruction* hlo) override; Status HandleElementwiseBinary(HloInstruction* hlo) override; Status HandleClamp(HloInstruction* clamp) override; @@ -155,11 +157,17 @@ class HloVerifier : public HloModulePass { public: using ShapeVerifierFactory = std::function()>; - explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision) + explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision, + std::function + instruction_can_change_layout_func = {}) : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] { return absl::make_unique(layout_sensitive, allow_mixed_precision); - }) {} + }), + instruction_can_change_layout_func_( + std::move(instruction_can_change_layout_func)) { + CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive); + } // Uses custom shape verification. explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory) @@ -172,22 +180,15 @@ class HloVerifier : public HloModulePass { StatusOr Run(HloModule* module) override; private: - // CHECKs various invariants of a fusion instruction. 
- Status CheckFusionInstruction(HloInstruction* fusion) const; - - Status CheckWhileInstruction(HloInstruction* instruction); - - Status CheckConditionalInstruction(HloInstruction* instruction); - - // Checks that the non-scalar operand shapes are compatible to the output - // shape, i.e., that there are no implicit broadcasts of size-one dimensions. - Status CheckElementwiseInstruction(HloInstruction* instruction); - // Creates a ShapeVerifier that checks that shapes match inferred // expectations. This is a factory function because ShapeVerifier, // being a DfsHloVisitor, is stateful. We want a clean object // for each run of the verifier. ShapeVerifierFactory shape_verifier_factory_; + + // Determines whether an instruction can change layouts. + std::function + instruction_can_change_layout_func_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 8f0423bb1c72ceb209437116a898d027f4d2c657..afe01e5487c3225815e01343d86e9fe894c2cde8 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -50,6 +51,14 @@ class HloVerifierTestAllowMixedPrecision : public HloTestBase { /*allow_mixed_precision_in_hlo_verifier=*/true) {} }; +class HloVerifierTestLayoutSensitive : public HloTestBase { + public: + HloVerifierTestLayoutSensitive() + : HloTestBase(/*verifier_layout_sensitive=*/true, + /*allow_mixed_precision_in_hlo_verifier=*/false, + LayoutAssignment::InstructionCanChangeLayout) {} +}; + TEST_F(HloVerifierTest, NullInstructionParent) { HloComputation::Builder builder(TestName()); const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -358,5 +367,63 @@ TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) { HasSubstr("non-positive base area dilation factor")); } +static const char* const kAddWithLayoutChangeHlo = R"( + HloModule AddWithLayoutChange + ENTRY AddWithLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[3,4]{0,1} parameter(1) + ROOT add0 = f32[3,4]{1,0} add(par0,par1) + } + )"; + +TEST_F(HloVerifierTest, AddWithLayoutChange) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} + +TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) { + const char* const kSliceWithLayoutChangeHlo = 
R"( + HloModule SliceWithLayoutChange + ENTRY SliceWithLayoutChange { + par0 = f32[4,5]{0,1} parameter(0) + par1 = s32[2] parameter(1) + ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1), + dynamic_slice_sizes={3,4} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kSliceWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} + +TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) { + const char* const kConcatWithLayoutChangeHlo = R"( + HloModule ConcatWithLayoutChange + ENTRY ConcatWithLayoutChange { + par0 = f32[3,5]{0,1} parameter(0) + par1 = f32[3,3]{1,0} parameter(1) + ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1), + dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kConcatWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 06f0e1ed25e71659a61e6de8a84e52cf70064eae..1ebb3319779c00fd4afe90606bf336e16349429d 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -23,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { namespace gtl = ::tensorflow::gtl; @@ -95,7 +96,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( absl::InlinedVector stack; enum DfsState { kDiscovered, kVisited }; - gtl::FlatMap dfs_state_map; + absl::flat_hash_map dfs_state_map; stack.push_back(root); InsertOrDie(&dfs_state_map, root, kDiscovered); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index 3e238f97a03fb71cddf59da69b0389731314ff49..e5aa67fd850de647652d66017223e19fb92cc068 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -18,10 +18,10 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/util/ptr_util.h" namespace xla { @@ -360,7 +360,7 @@ class IndexedArrayAnalysis { std::vector> owned_tensors_; std::vector owned_literals_; - tensorflow::gtl::FlatMap cache_; + absl::flat_hash_map cache_; }; // A pass that prints all non-trivial results returned by IndexedArrayAnalysis. diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index e884122fcb4042fe54a419c7a527d49604f97bd7..69a4c160ee5c4539272c3085338dc6de1b9347ff 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -22,11 +22,12 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -189,7 +190,7 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { bool InstructionFusion::CanFuseOnAllPaths( HloInstruction* producer, HloInstruction* consumer, const HloInstructionSet& do_not_fuse, - tensorflow::gtl::FlatMap, bool>* + absl::flat_hash_map, bool>* result_cache) { if (consumer == producer) { return true; @@ -241,7 +242,7 @@ InstructionFusion::ComputeGloballyUnfusible( // fusing operations that require duplication later depending on // is_expensive_(). HloInstructionSet do_not_duplicate; - tensorflow::gtl::FlatMap, bool> + absl::flat_hash_map, bool> can_fuse_on_all_paths_result_cache; for (HloInstruction* consumer : post_order) { for (HloInstruction* producer : consumer->operands()) { @@ -430,7 +431,7 @@ class ReversePostOrderFusionQueue : public FusionQueue { private: std::vector post_order_; - tensorflow::gtl::FlatMap post_order_index_; + absl::flat_hash_map post_order_index_; }; } // namespace diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index c1ec3b18a19724f5068f7099cfe19547e7bb2784..f14c6675208c72112aea0179c238b58709d625b5 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -1,3 +1,4 @@ +#include "absl/container/flat_hash_map.h" /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +17,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ +#include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -24,33 +26,6 @@ limitations under the License. namespace xla { -// A queue interface that allows implementations to choose fusion candidates in -// custom order. -class FusionQueue { - public: - FusionQueue() = default; - virtual ~FusionQueue() = default; - - // Dequeues the next fusion candidates: a consumer and the list of producers - // as operand indices. - virtual std::pair> - DequeueNextInstructionAndOperandsToFuseInOrder() = 0; - - // A callback passed to the queue implementation right before the producer is - // fused into the consumer. - virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {} - - // A callback passed to the queue implementation right after the fusion is - // created. Note that original_producer could have been destroyed. - virtual void OnFusingInstruction(HloInstruction* fusion, - HloInstruction* original_producer, - HloInstruction* original_consumer) {} - - // A callback passed to the queue implementation to notify the removal of an - // instruction. - virtual void RemoveInstruction(HloInstruction* instruction) = 0; -}; - // HLO pass which performs instruction fusion. 
Instructions are fused // "vertically", meaning producing instructions are fused into their consumers // with the intent that the loops which compute their values will be fused in @@ -158,8 +133,8 @@ class InstructionFusion : public HloModulePass { bool CanFuseOnAllPaths( HloInstruction* producer, HloInstruction* consumer, const HloInstructionSet& do_not_fuse, - tensorflow::gtl::FlatMap, - bool>* result_cache); + absl::flat_hash_map, bool>* + result_cache); // Computes the set of nodes that we do not want to fuse into any of their // consumers based on a global analysis of the HLO graph. diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 146c9052f10cca8b199a480491d9a672d8bebdff..1484e14df10d94841c5a2e849761779f5800392d 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -45,8 +45,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", - "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:layout_assignment", + "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index bb69cb9c47ff2c7de8d13832c4b8e6216c62da73..a1fe97cffa4de1993396d2443166321bc795b553 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -28,9 +28,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" -#include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -44,7 +44,8 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass( - hlo_module->mutable_entry_computation_layout()); + hlo_module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout); return pipeline.Run(hlo_module).status(); } @@ -56,6 +57,12 @@ StatusOr> InterpreterCompiler::RunHloPasses( return std::move(hlo_module); } +Status InterpreterCompiler::RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, se::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator) { + return Unimplemented("Module group compilation not supported on Interpreter"); +} + StatusOr> InterpreterCompiler::RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* /*device_allocator*/) { @@ -75,17 +82,45 @@ StatusOr> InterpreterCompiler::RunBackend( return std::move(executable); } +StatusOr>> +InterpreterCompiler::RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) { + return Unimplemented( + "Module group compilation is not supported on Interpreter."); +} + StatusOr>> InterpreterCompiler::Compile( - std::vector> /*hlo_modules*/, - std::vector> /*stream_execs*/, - DeviceMemoryAllocator* /*device_allocator*/) { - return 
tensorflow::errors::Unimplemented( - "Compilation of multiple HLO modules is not supported on Interpreter."); + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) { + if (module_group->empty()) { + return std::vector>(); + } + if (module_group->size() > 1) { + return tensorflow::errors::Unimplemented( + "Compilation of multiple HLO modules is not supported on Interpreter."); + } + if (stream_exec.size() != 1 || stream_exec[0].size() != 1) { + return tensorflow::errors::Unimplemented( + "Unexpected number of StreamExecutor's."); + } + auto hlo_modules = module_group->ConsumeModules(); + TF_ASSIGN_OR_RETURN(auto module, + RunHloPasses(std::move(hlo_modules[0]), stream_exec[0][0], + device_allocator)); + TF_ASSIGN_OR_RETURN( + auto executable, + RunBackend(std::move(module), stream_exec[0][0], device_allocator)); + std::vector> ret; + ret.push_back(std::move(executable)); + return std::move(ret); } StatusOr>> InterpreterCompiler::CompileAheadOfTime( - std::vector> hlo_modules, + std::unique_ptr module_group, const AotCompilationOptions& aot_options) { return tensorflow::errors::InvalidArgument( "AOT compilation not supported on Interpreter"); diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index e90ae3e818522e6e4fd9d9f5acb846800bc899ca..d8cb32c0beb279ae6484b1b8f5f99085c2d67c67 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -46,18 +46,25 @@ class InterpreterCompiler : public Compiler { StatusOr> RunHloPasses( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; + Status RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, se::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, 
DeviceMemoryAllocator* device_allocator) override; + StatusOr>> RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr>> Compile( - std::vector> hlo_modules, + std::unique_ptr module_group, std::vector> stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr>> - CompileAheadOfTime(std::vector> hlo_modules, + CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& aot_options) override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 082bf8bffed484244139e79f4d3fe30ca091d8ac..232d1dc0879cd6931158e642e01fe68e43e6c655 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -419,6 +419,16 @@ Status LayoutAssignment::BuildHostChannelConstraints( return Status::OK(); } +namespace { + +bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + return custom_call != nullptr && custom_call->layout_constrained(); +} + +} // namespace + Status LayoutAssignment::AddMandatoryConstraints( const ComputationLayout* computation_layout, ChannelLayoutConstraints* channel_constraints, HloComputation* computation, @@ -434,7 +444,6 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain layouts of instructions which define values with pre-existing // layouts. for (auto* instruction : computation->instructions()) { - Shape const* shape_with_layout = nullptr; if (instruction->opcode() == HloOpcode::kInfeed) { // Infeed layouts must match the layout of the original inserted // instruction. 
@@ -456,17 +465,21 @@ Status LayoutAssignment::AddMandatoryConstraints( if (parameter_layout.LayoutIsSet()) { // Parameter layouts must match the respective layout in // ComputationLayout, if there is one. - shape_with_layout = ¶meter_layout.shape(); + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + parameter_layout.shape(), instruction)); } } - } - if (shape_with_layout != nullptr) { + } else if (IsLayoutConstrainedCustomCall(instruction)) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(*shape_with_layout, instruction)); - } - - if (instruction->opcode() == HloOpcode::kSend || - instruction->opcode() == HloOpcode::kRecv) { + constraints->SetInstructionLayout(custom_call->shape(), custom_call)); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + custom_call->operand_shapes_with_layout()[i], custom_call, i)); + } + } else if (instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kRecv) { CHECK(get_channel_constraints(instruction)) << "Multi-module layout assignment requires ChannelLayoutConstraints"; int64 channel_id = instruction->channel_id(); @@ -498,6 +511,22 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR( constraints->SetBufferLayout(new_shape.layout(), *buffer)); } + } else if (instruction->IsCrossModuleAllReduce()) { + CHECK(get_channel_constraints(instruction)) + << "Multi-module layout assignment requires ChannelLayoutConstraints"; + int64 all_reduce_id = instruction->all_reduce_id().value(); + if (!get_channel_constraints(instruction) + ->IsChannelConstrained(all_reduce_id)) { + continue; + } + // TODO(b/68493863): Change to use SetOperandLayout(). 
+ const Shape& buffer_shape = instruction->operand(0)->shape(); + TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape)); + Shape new_buffer_shape = + get_channel_constraints(instruction) + ->LayoutShapeForChannel(buffer_shape, all_reduce_id); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(new_buffer_shape, instruction)); } } @@ -605,31 +634,6 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetOperandLayout( false_computation_layout.parameter_shape(0), instruction, 2, /*mandatory=*/true)); - } else if (instruction->opcode() == HloOpcode::kCustomCall) { - if (!CustomCallRequiresMajorFirstLayout(instruction)) { - continue; - } - // Add constraints for kCustomCall instruction operands and instructions. - // For now we only support major-first layouts for all inputs and outputs. - Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout( - instruction->shape().element_type(), - AsInt64Slice(instruction->shape().dimensions())); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(result_shape, instruction)); - for (int64 i = 0; i < instruction->operand_count(); ++i) { - const Shape& operand_shape = instruction->operand(i)->shape(); - // Opaque operands don't get a layout constraint. - if (ShapeUtil::IsOpaque(operand_shape)) { - continue; - } - - Shape row_major_operand_shape = - ShapeUtil::MakeShapeWithDescendingLayout( - operand_shape.element_type(), - AsInt64Slice(operand_shape.dimensions())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - row_major_operand_shape, instruction, i)); - } } } // Finally set the result layout to match ComputationLayout, if there is one. @@ -660,16 +664,18 @@ Status CheckCallLayout(HloInstruction* call, return Status::OK(); } -// Custom calls have fixed input and output layouts. 
-Status CheckCustomCallLayout(HloInstruction* custom_call) { - for (const HloInstruction* operand : custom_call->operands()) { - TF_RET_CHECK( - ShapeUtil::IsOpaque(operand->shape()) || - LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); +// Operands of layout-constrained custom calls must match the expected +// constrained layouts. +Status CheckCustomCallLayout(HloInstruction* instruction) { + if (IsLayoutConstrainedCustomCall(instruction)) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( + custom_call->operand(i)->shape(), + custom_call->operand_shapes_with_layout()[i])); + } } - TF_RET_CHECK( - ShapeUtil::IsOpaque(custom_call->shape()) || - LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout())); return Status::OK(); } @@ -776,21 +782,27 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( << " instruction: " << instruction->ToString(); if (ShapeUtil::IsTuple(instruction->shape())) { - // Deep-copy tuples. + // Copy tuple elements which have differing layouts. std::vector element_copies; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); ++i) { + const Shape& target_shape = + ShapeUtil::GetSubshape(shape_with_layout, {i}); + const Shape& instr_shape = + ShapeUtil::GetSubshape(instruction->shape(), {i}); HloInstruction* gte = instruction->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, - i)); - SetupCopiedInstruction(*instruction, gte, {i}); - // Recurse to copy each elements. 
- TF_ASSIGN_OR_RETURN( - HloInstruction * element_copy, - CreateCopyWithNewLayout( - ShapeUtil::GetSubshape(shape_with_layout, {i}), gte)); - element_copies.push_back(element_copy); + HloInstruction::CreateGetTupleElement(instr_shape, instruction, i)); + + if (ShapeUtil::Equal(target_shape, instr_shape)) { + // Shapes and layouts are equal, no need to copy. + element_copies.push_back(gte); + } else { + SetupCopiedInstruction(*instruction, gte, {i}); + // Recurse to copy each element. + TF_ASSIGN_OR_RETURN(HloInstruction * element_copy, + CreateCopyWithNewLayout(target_shape, gte)); + element_copies.push_back(element_copy); + } } // Gather element copies into a tuple with a new Tuple instruction. HloInstruction* tuple_copy = instruction->parent()->AddInstruction( @@ -910,9 +922,7 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { FindOrDie(computation_layouts_, instruction->to_apply()))); break; case HloOpcode::kCustomCall: - if (CustomCallRequiresMajorFirstLayout(instruction)) { - TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); - } + TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); break; case HloOpcode::kFusion: TF_RETURN_IF_ERROR(CheckFusionLayout(instruction)); @@ -949,19 +959,23 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { FindOrDie(computation_layouts_, module->entry_computation()) .result_layout(); if (result_layout.LayoutIsSet()) { - TF_RET_CHECK(ShapeUtil::Equal( - module->entry_computation()->root_instruction()->shape(), - result_layout.shape())); + TF_RET_CHECK( + ShapeUtil::Equal(module->result_shape(), result_layout.shape())); } return Status::OK(); } LayoutAssignment::LayoutAssignment( ComputationLayout* entry_computation_layout, + std::function + instruction_can_change_layout_func, ChannelLayoutConstraints* channel_constraints) : entry_computation_layout_(entry_computation_layout), + saved_entry_computation_layout_(*entry_computation_layout), - channel_layout_constraints_(channel_constraints) { + 
channel_layout_constraints_(channel_constraints), + instruction_can_change_layout_func_( + std::move(instruction_can_change_layout_func)) { if (channel_layout_constraints_ != nullptr) { // Save a copy of the input ChannelLayoutConstraints so that we can reset it // if we have to undo previous operations (ClearPreviousPassSideEffects()). @@ -982,7 +996,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (!ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(instruction->shape()) && - InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) { + !instruction_can_change_layout_func_(instruction)) { // Propagate the result layout to the operand layout if the instruction // requires the same layout out for the result and the operand. // @@ -1060,7 +1074,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (!ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) && - InstructionRequiresInputLayoutEqualToOutputLayout(user)) { + !instruction_can_change_layout_func_(user)) { // Assign users the same layout as the operand. return absl::make_unique(operand_layout); } @@ -1509,22 +1523,13 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, // Execute extra verification step once the layout has been finalized. TF_RETURN_IF_ERROR(Verify(instruction)); + // Shape must be valid. + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape())); + // Verify all layouts in the shape have been set. TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); } - - // Copy the root instruction's result if its layout does not match the result - // layout constraint. 
- if (constraints.ResultLayout() != nullptr && - !constraints.ResultLayout()->MatchesLayoutInShape( - computation->root_instruction()->shape())) { - TF_ASSIGN_OR_RETURN( - HloInstruction * new_root, - CreateCopyWithNewLayout(constraints.ResultLayout()->shape(), - computation->root_instruction())); - computation->set_root_instruction(new_root); - } - return Status::OK(); } @@ -1540,11 +1545,11 @@ Status LayoutAssignment::CalculateComputationLayout( Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // Clear existing layouts of the instructions. All layouts must be assigned - // by the LayoutAssignment pass, except for those on infeeds, parameters, - // and the computation result. The latter two are specified in - // computation_layout, so we only need to keep the existing layouts for - // infeeds. Clearing the layouts here avoids hiding potential bugs in the - // layout assignment pass that may accidentally use the existing layout. + // by the LayoutAssignment pass, except for those on parameters, the + // computation result, and a couple special cases. The former two are + // specified in computation_layout. Clearing the layouts here avoids hiding + // potential bugs in the layout assignment pass that may accidentally use the + // existing layout. for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kBitcast) { // bitcasts are inherently layout sensitive and so a bitcast instruction @@ -1553,7 +1558,9 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { "Unexpected bitcast operation seen during layout assignment: %s.", instruction->ToString()); } - if (instruction->opcode() != HloOpcode::kInfeed) { + // Some instructions carry mandatory layouts in their shape. 
+ if (instruction->opcode() != HloOpcode::kInfeed && + !IsLayoutConstrainedCustomCall(instruction)) { LayoutUtil::ClearLayout(instruction->mutable_shape()); } } @@ -1654,6 +1661,18 @@ Status LayoutAssignment::RunOnComputation( TF_RETURN_IF_ERROR( ConstrainChannelLayouts(computation, channel_constraints)); } + + // Copy the root instruction's result if its layout does not match the result + // layout constraint. + if (constraints.ResultLayout() != nullptr && + !constraints.ResultLayout()->MatchesLayoutInShape( + computation->root_instruction()->shape())) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_root, + CreateCopyWithNewLayout(constraints.ResultLayout()->shape(), + computation->root_instruction())); + computation->set_root_instruction(new_root); + } return Status::OK(); } @@ -1709,6 +1728,30 @@ Status LayoutAssignment::ConstrainChannelLayouts( ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0}); *send_shape = shape; } + } else if (instruction->IsCrossModuleAllReduce()) { + const Layout* layout = + get_channel_constraints(instruction) + ->ConstrainChannel(instruction->all_reduce_id().value(), + instruction->shape().layout()); + if (layout != nullptr) { + // We found an already constrained layout which does not match the one + // the channel wants to impose. Either add a new kCopy, or use the + // existing one to marshal the correct shape. 
+ HloInstruction* operand = instruction->mutable_operand(0); + Shape shape = operand->shape(); + *shape.mutable_layout() = *layout; + if (operand->opcode() != HloOpcode::kCopy) { + HloInstruction* copy = operand->parent()->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand)); + RegisterAddedCopy(copy); + SetupCopiedInstruction(*operand, copy, {}); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy)); + operand = copy; + } else { + *operand->mutable_shape() = shape; + } + *instruction->mutable_shape() = shape; + } } } return Status::OK(); @@ -1752,6 +1795,18 @@ StatusOr LayoutAssignment::Run(HloModule* module) { } TF_RETURN_IF_ERROR(Init()); + // Verify computation layout is sane. + const HloComputation* entry = module->entry_computation(); + TF_RET_CHECK(entry_computation_layout_->parameter_count() == + entry->num_parameters()); + for (int64 i = 0; i < entry->num_parameters(); ++i) { + TF_RET_CHECK( + ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i), + entry->parameter_instruction(i)->shape())); + } + TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(), + entry->root_instruction()->shape())); + // We do two passes. 
The first one we pass a nullptr ComputationLayout to // the RunOnComputation() calls (for non entry computations), and we register // the ComputationLayout which are naturally flowing in DFS fashion to the @@ -1803,7 +1858,8 @@ StatusOr LayoutAssignment::Run(HloModule* module) { return true; } -bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( +/* static */ +bool LayoutAssignment::InstructionCanChangeLayout( const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kAbs: @@ -1822,7 +1878,6 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: - case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: @@ -1856,6 +1911,7 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( case HloOpcode::kRemainder: case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: + case HloOpcode::kScatter: case HloOpcode::kSelect: case HloOpcode::kSelectAndScatter: case HloOpcode::kShiftLeft: @@ -1869,7 +1925,7 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( case HloOpcode::kTanh: case HloOpcode::kTupleSelect: case HloOpcode::kWhile: - return true; + return false; case HloOpcode::kBatchNormGrad: case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormTraining: @@ -1879,6 +1935,7 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( case HloOpcode::kConstant: case HloOpcode::kConvolution: case HloOpcode::kCopy: + case HloOpcode::kCustomCall: case HloOpcode::kDomain: case HloOpcode::kDot: case HloOpcode::kFusion: @@ -1893,14 +1950,13 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( case HloOpcode::kReduce: case HloOpcode::kReshape: case HloOpcode::kRng: - case HloOpcode::kScatter: case HloOpcode::kSend: case HloOpcode::kSendDone: case 
HloOpcode::kAfterAll: case HloOpcode::kTrace: case HloOpcode::kTranspose: case HloOpcode::kTuple: - return false; + return true; } } diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index e29c199c42a4878daaf2eeb86b6909d6d3ff920e..cb56f4cd19ded036ef521a579eb7d6ea7f3b6268 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -25,6 +25,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -38,8 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -228,8 +228,8 @@ class LayoutConstraints { // Array-shaped buffers which have not yet been constrained. std::set unconstrained_buffer_ids_; - mutable tensorflow::gtl::FlatMap> + mutable absl::flat_hash_map> buffer_sets_cache_; HloComputation* computation_; @@ -286,6 +286,11 @@ class LayoutAssignment : public HloModulePass { // entry_computation_layout is modified to populate a layout for the result in // the case that no particular layout is requested. // + // instruction_can_change_layout_func is a function object that determines + // whether an instruction can change layouts. An instruction not being able to + // change layout means that it requires operands with the same rank as the + // output to have the same layout as the output. + // // channel_constraints is both an input and output. 
Any sends or recvs that // are present in channel_constraints will be laid out as constrained. Any // unconstrained sends or recvs will be laid out as locally optimal and their @@ -295,6 +300,8 @@ class LayoutAssignment : public HloModulePass { // within any module passed to `Run`. explicit LayoutAssignment( ComputationLayout* entry_computation_layout, + std::function + instruction_can_change_layout_func = InstructionCanChangeLayout, ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment() override {} absl::string_view name() const override { return "layout-assignment"; } @@ -303,10 +310,10 @@ class LayoutAssignment : public HloModulePass { // (any layouts were changed). StatusOr Run(HloModule* module) override; - // Returns true if the instruction requires that operands with the same rank - // as the output have to have the same layout as the output. - virtual bool InstructionRequiresInputLayoutEqualToOutputLayout( - const HloInstruction* instruction); + // Determines whether an instruction can change layouts. An instruction not + // being able to change layout means that it requires operands with the same + // rank as the output to have the same layout as the output. + static bool InstructionCanChangeLayout(const HloInstruction* instruction); protected: // These methods, invoked by PropagateConstraints, propagate a layout @@ -326,19 +333,6 @@ class LayoutAssignment : public HloModulePass { const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints); - // By default LayoutAssignment ensures that inputs and outputs of CustomCalls - // have the "major-first" layout (i.e. {n, n-1, ..., 0}). - // - // If this function returns true, LayoutAssignment does not set a layout for - // the given CustomCall. It's up to the backend to set one in - // AddBackendConstraints, if necessary. - // - // Precondition: instruction->opcode() == HloOpcode::kCustomCall. 
- virtual bool CustomCallRequiresMajorFirstLayout( - const HloInstruction* /*instruction*/) { - return true; - } - // Called after layouts of an instruction have been finalized to allow // subclasses to check for platform specific assumptions. virtual Status Verify(const HloInstruction* instruction) { @@ -504,7 +498,7 @@ class LayoutAssignment : public HloModulePass { // Every copy added to the module by the layout assignment pass is registered // here. - tensorflow::gtl::FlatSet added_copies_; + absl::flat_hash_set added_copies_; // The pointer to the channel layout constraints passed in with the // constructor. If not nullptr, this is an input/output argument. @@ -521,8 +515,10 @@ class LayoutAssignment : public HloModulePass { // The set of HLO instructions which lacked any layout constraint, thus // receiving propagated default layouts. - tensorflow::gtl::FlatSet - unconstrained_layout_instructions_; + absl::flat_hash_set unconstrained_layout_instructions_; + + std::function + instruction_can_change_layout_func_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 752a61476dd7892a2b7f531c4057015f48fc4758..a831751fa96f8cef233e16fe02378ac036efc8ab 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -55,7 +55,8 @@ class LayoutAssignmentTest : public HloVerifiedTestBase { ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr) { LayoutAssignment layout_assignment( - entry_computation_layout, /*channel_constraints=*/channel_constraints); + entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout, + /*channel_constraints=*/channel_constraints); EXPECT_IS_OK(layout_assignment.Run(module).status()); } @@ -64,6 +65,27 @@ class LayoutAssignmentTest : public HloVerifiedTestBase { FindInstruction(module, 
name)->shape().layout().minor_to_major(); return std::vector(minor_to_major.begin(), minor_to_major.end()); } + + void ExpectLayoutIs(const Shape& shape, + absl::Span minor_to_major) { + const Layout expected = LayoutUtil::MakeLayout(minor_to_major); + EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected)) + << "Expected layout " << expected << ", actual " << shape.layout(); + } + + void ExpectTupleLayoutIs( + const Shape& shape, + std::initializer_list> minor_to_majors) { + int i = 0; + for (const absl::Span minor_to_major : minor_to_majors) { + const Layout expected = LayoutUtil::MakeLayout(minor_to_major); + const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout(); + EXPECT_TRUE(LayoutUtil::Equal(actual, expected)) + << "Expected tuple element " << i << " layout " << expected + << ", actual " << actual; + ++i; + } + } }; TEST_F(LayoutAssignmentTest, ComputationLayout) { @@ -860,6 +882,50 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); } +TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { + // Pin non matching layouts to parameter and root. 
+ const char* module_str = R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY entry_computation { + param = (f32[2,2]) parameter(0) + gte = f32[2,2] get-tuple-element(param), index=0 + ar.0 = f32[2,2] cross-replica-sum(gte), + all_reduce_id=0, replica_groups={{0}}, to_apply=add, + sharding={maximal device=0} + const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}}) + ROOT ar.1 = f32[2,2] cross-replica-sum(const), + all_reduce_id=0, replica_groups={{0}}, to_apply=add, + sharding={maximal device=1} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + Shape param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); + TF_ASSERT_OK( + computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape( + param_shape)); + computation_layout.mutable_result_layout()->ResetLayout( + LayoutUtil::MakeLayout({1, 0})); + + ChannelLayoutConstraints channel_constraints; + AssignLayouts(module.get(), &computation_layout, &channel_constraints); + + EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "ar.0"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "ar.1"), ElementsAre(0, 1)); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0)); +} + TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { const char* module_str = R"( HloModule CopySliceOperandToAvoidImplicitLayoutChange @@ -998,5 +1064,232 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { op::ShapeWithLayout(shape_copy)))); } +TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { + // The first infeed uses layout {0,1}, while the second uses layout {1,0}. 
+ // The mismatch forces a copy of the tuple. The tuple contains a token, so + // layout assignment will fail if it tries to copy the whole tuple. + const char* module_str = R"( + HloModule TupleCopyOnLayoutMismatch + + condition.1 (tup: (s32[], token[], f32[512,1024]{0,1})) -> pred[] { + tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0) + counter.1 = s32[] get-tuple-element(tup.1), index=0 + five = s32[] constant(5) + ROOT lt = pred[] less-than(counter.1, five) + } + + body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) { + tup.2 = (s32[], token[], f32[512,1024]{0,1}) parameter(0) + counter.2 = s32[] get-tuple-element(tup.2), index=0 + tok.2 = token[] get-tuple-element(tup.2), index=1 + + ifeed.2 = (f32[512,1024]{1,0}, token[]) infeed(tok.2) + next_tok = token[] get-tuple-element(ifeed.2), index=1 + next_buf = f32[512,1024]{1,0} get-tuple-element(ifeed.2), index=0 + + one = s32[] constant(1) + next_counter = s32[] add(counter.2, one) + ROOT tup = (s32[], token[], f32[512,1024]{0,1}) tuple(next_counter, next_tok, next_buf) + } + + ENTRY main () -> f32[512,1024]{0,1} { + start_tok = token[] after-all() + + ifeed.3 = (f32[512,1024]{0,1}, token[]) infeed(start_tok) + itok = token[] get-tuple-element(ifeed.3), index=1 + ibuf = f32[512,1024]{0,1} get-tuple-element(ifeed.3), index=0 + + zero = s32[] constant(0) + itup = (s32[], token[], f32[512,1024]{0,1}) tuple(zero, itok, ibuf) + + loop = (s32[], token[], f32[512,1024]{0,1}) while(itup), condition=condition.1, body=body.2 + ROOT result = f32[512,1024]{0,1} get-tuple-element(loop), index=2 + } + )"; + + ParseAndVerifyModule(module_str); + ComputationLayout computation_layout( + module().entry_computation()->ComputeProgramShape()); + + // Sanity check to verify that there's a layout mismatch. 
+ EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); + + AssignLayouts(&module(), &computation_layout); + + // Make sure that layout assignment did not magically eliminate the mismatch, + // in which case the test didn't prove anything. + EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); +} + +TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) { + const char* module_str = R"( +HloModule CustomCallNotLayoutConstrained + +ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] { + %p = f32[42,2,3] parameter(0) + ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz" +} +)"; + // Try with a couple different layouts. In each case the custom calls operand + // and result layout should match that of the computation. + { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1})); + AssignLayouts(module.get(), &computation_layout); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::CustomCall(op::Parameter())); + ExpectLayoutIs(root->shape(), {3, 2, 0, 1}); + ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1}); + } + { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 
2, 3}, {0, 1, 2})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1})); + AssignLayouts(module.get(), &computation_layout); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::CustomCall(op::Parameter())); + ExpectLayoutIs(root->shape(), {0, 2, 3, 1}); + ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2}); + } +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrained + +ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] { + %p0 = f32[4,4] parameter(0) + %p1 = f32[2,3] parameter(1) + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(module.get(), &computation_layout); + + // The custom call should be partially encapsulated in kCopy instructions + // because of the layout mismatches. 
+ ASSERT_THAT(module->entry_computation()->root_instruction(), + op::Copy(op::CustomCall(op::Copy(), op::Parameter()))); + + const HloInstruction* custom_call = + module->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); + ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1}); + ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedZeroOperands + +ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] { + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(module.get(), &computation_layout); + + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::Copy(op::CustomCall())); + + const HloInstruction* custom_call = + module->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedTupleOperand + +ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] { + %p0 = f32[4,4] parameter(0) + %p1 = f32[2,3] parameter(1) + %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1) + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + 
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(module.get(), &computation_layout); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ExpectLayoutIs(root->shape(), {2, 1, 0, 3}); + + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::Copy(op::CustomCall(op::Tuple()))); + + const HloInstruction* custom_call = + module->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); + ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedTupleResult + +ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) { + %p0 = f32[4,4] parameter(0) + ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}} +} +)"; + // The computation's result layout differs from the custom call's result + // layout in the second tuple element; the custom call keeps its own layout.
+ TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})})); + AssignLayouts(module.get(), &computation_layout); + + ExpectTupleLayoutIs(module->result_shape(), {{1, 0}, {1, 0}}); + + const HloInstruction* custom_call = + FindInstruction(module.get(), "custom-call"); + ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc index b17c9d504501a907e27d5152e0082799e87443c7..d287aa4ec7bbcd11f51ea07cd2a1572e59f0d6c6 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.cc +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -21,8 +21,24 @@ limitations under the License. 
#endif namespace xla { +Status LLVMCompiler::RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, se::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator) { + return Unimplemented( + "Model partitioning not implemented for the CPU/GPU compilers!"); +} + +StatusOr>> +LLVMCompiler::RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) { + return Unimplemented( + "Model partitioning not implemented for the CPU/GPU compilers!"); +} + StatusOr>> LLVMCompiler::Compile( - std::vector> modules, + std::unique_ptr module_group, std::vector> stream_execs, DeviceMemoryAllocator* device_allocator) { // Tensorflow tries to enable the following behaviors in all its threads: @@ -38,6 +54,8 @@ StatusOr>> LLVMCompiler::Compile( tensorflow::port::ScopedDontFlushDenormal dont_flush_denormals; std::vector> result; + std::vector> modules = + module_group->ConsumeModules(); for (size_t i = 0; i < modules.size(); i++) { if (stream_execs[i].size() != 1) { return Unimplemented( diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index f1c623508c5307f2b1c036d3ec6823b75c7eda13..86abd5da0189feb0eadfde3d6dbab446eb2be900 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -69,8 +69,17 @@ class LLVMCompiler : public Compiler { using Compiler::RunBackend; using Compiler::RunHloPasses; + Status RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, se::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator) override; + + StatusOr>> RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) override; + StatusOr>> Compile( - std::vector> modules, + std::unique_ptr module_group, std::vector> stream_execs, DeviceMemoryAllocator* device_allocator) override; diff --git 
a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 540bbb7c7a74f65ab70f4c6704d6600db2adbb60..5f7ad81d82978d0a752b33d12b72e16f0c1c6826 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -38,6 +38,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm//:core", ], @@ -202,7 +204,6 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:partition_assignment", "//tensorflow/core:lib", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", "@llvm//:core", "@llvm//:support", ], diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index e5370eca56f2e3a891523ba2b72961d66ec809aa..643ecd0fbaa546c551097b29e74ccd49418e1466 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" -#include +#include #include "llvm/IR/MDBuilder.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -164,9 +164,7 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( add_buffers_to_worklist(operand); } - tensorflow::gtl::FlatSet - buffers; + std::set buffers; for (const LogicalBuffer* buffer : worklist) { // Skip buffers which cannot be added to the noalias set. 
if (!assignment.HasAllocation(*buffer) || diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index 8d9fa99d82b4e49b653d9f05cc9baa5e3fdcefa6..2b46b3c3964b15548dbacc8b0ada0047a0fa85b6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -16,14 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { namespace llvm_ir { @@ -77,14 +76,14 @@ class AliasAnalysis { // A map from a buffer slice to metadata corresponding to its alias.scope // metadata. The index kParameterAliasSet is used to hold aliasing // information for parameters. - tensorflow::gtl::FlatMap + absl::flat_hash_map alias_scope_metadata_; // A map from a buffer slice to metadata corresponding to its noalias // metadata. - tensorflow::gtl::FlatMap + absl::flat_hash_map noalias_metadata_; }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index 944c79580c133906cd431722fd6b29e6aee5f918..05ba4a40da413f0e774214e55ef69d023afc48e2 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -15,9 +15,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" +#include + // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/strings/string_view.h" -#include "absl/types/optional.h" #include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -43,7 +44,7 @@ namespace { void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, const IrArray::Index& compare_keys_index, const IrArray& keys_array, - const absl::optional& values_array, + const std::vector& values_arrays, llvm::IRBuilder<>* b) { // if (is_smaller_index && // compare_keys[dimension_to_sort] < dimension_to_sort_bound) @@ -100,19 +101,18 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, // Swap key1 with key2. keys_array.EmitWriteArrayElement(keys_index, key2, b); keys_array.EmitWriteArrayElement(compare_keys_index, key1, b); - if (values_array.has_value()) { + for (const auto& values_array : values_arrays) { // Also swap the values. 
- auto value1 = values_array.value().EmitReadArrayElement(keys_index, b); - auto value2 = - values_array.value().EmitReadArrayElement(compare_keys_index, b); - values_array.value().EmitWriteArrayElement(keys_index, value2, b); - values_array.value().EmitWriteArrayElement(compare_keys_index, value1, b); + auto value1 = values_array.EmitReadArrayElement(keys_index, b); + auto value2 = values_array.EmitReadArrayElement(compare_keys_index, b); + values_array.EmitWriteArrayElement(keys_index, value2, b); + values_array.EmitWriteArrayElement(compare_keys_index, value1, b); } } } // namespace Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const absl::optional& values_array, + const std::vector& values_arrays, absl::string_view name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions) { @@ -162,7 +162,7 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, compare_keys_index[dimension_to_sort] = b->CreateXor(compare_index[0], xor_mask); EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, - keys_array, values_array, b); + keys_array, values_arrays, b); return Status::OK(); }; if (launch_dimensions != nullptr) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 527ed10374ce9482045a8459e38fd041e0e83001..2f3bcda2307bcbb35a03b9e71dbbe44e366b3820 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -16,8 +16,9 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ +#include + #include "absl/strings/string_view.h" -#include "absl/types/optional.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -31,7 +32,7 @@ namespace llvm_ir { // implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr, // the inner compare loop will not be parallelized. Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const absl::optional& values_array, + const std::vector& values_arrays, absl::string_view name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 0d0fb7946ae6815905491ca55652d7d0ab278a3c..cca37556173bb95ef062b59ab0a4bf9ca7c496fe 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -144,8 +144,8 @@ StatusOr> LocalService::CompileExecutable( const absl::Span argument_layouts, const ExecutableBuildOptions& build_options) { const HloModuleProto& proto = computation.proto(); - TF_RET_CHECK(proto.has_program_shape()); - const ProgramShape& program_shape = proto.program_shape(); + TF_RET_CHECK(proto.has_host_program_shape()); + const ProgramShape& program_shape = proto.host_program_shape(); // Validate incoming layouts. 
if (argument_layouts.size() != program_shape.parameters_size()) { diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/map_inliner.cc similarity index 76% rename from tensorflow/compiler/xla/service/inliner.cc rename to tensorflow/compiler/xla/service/map_inliner.cc index 5fd779ebf9b59e34a0844cc3a898bb72ce6044ee..2200ef054a6993fb884751643ab1fb5ab83efe05 100644 --- a/tensorflow/compiler/xla/service/inliner.cc +++ b/tensorflow/compiler/xla/service/map_inliner.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/inliner.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include #include @@ -32,10 +32,10 @@ limitations under the License. namespace xla { -// InlinerVisitor traverses the HLO computation and inlines maps. -class InlinerVisitor : public DfsHloVisitorWithDefault { +// MapInlinerVisitor traverses the HLO computation and inlines maps. +class MapInlinerVisitor : public DfsHloVisitorWithDefault { public: - explicit InlinerVisitor(HloComputation* computation) + explicit MapInlinerVisitor(HloComputation* computation) : computation_(computation) {} // Default visitor action is to do nothing and return OK. @@ -49,48 +49,44 @@ class InlinerVisitor : public DfsHloVisitorWithDefault { StatusOr Run(HloComputation* computation); private: - // Current HloComputation instance the InlinerVisitor is traversing. + // Current HloComputation instance the MapInlinerVisitor is traversing. HloComputation* computation_; // Whether algebraic simplification has occurred. 
bool changed_ = false; }; -StatusOr InlinerVisitor::Run(HloComputation* computation) { +StatusOr MapInlinerVisitor::Run(HloComputation* computation) { changed_ = false; computation_ = computation; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this)); return changed_; } -Status InlinerVisitor::HandleMap(HloInstruction* map) { +Status MapInlinerVisitor::HandleMap(HloInstruction* map) { HloComputation* function = map->to_apply(); HloInstruction& root = *function->root_instruction(); - // TODO(b/29249531): Add DCE pass to remove unused HloComputations. // Only inlining functions that are simply a single operation until a better // profitability model for inlining is defined. if (hlo_query::AllOperandsAreParameters(root)) { if (root.opcode() == HloOpcode::kFusion || - root.opcode() == HloOpcode::kParameter || root.opcode() == HloOpcode::kTrace) { // Cloning not supported for these instructions. return Status::OK(); } VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function " << root.ToShortString(); - // If the input is a constant then the shape of the constant could be - // different than the map shape. Hence, a broadcast is needed, else the - // cloned operand with new shape and operands work. - if (root.opcode() != HloOpcode::kConstant) { - std::vector params; - for (int64 o = 0; o < root.operands().size(); o++) { - params.push_back(map->operands()[root.operand(o)->parameter_number()]); - } - HloInstruction* placed_instruction = computation_->AddInstruction( - root.CloneWithNewOperands(map->shape(), params)); + if (root.opcode() == HloOpcode::kParameter) { + // If the root is a parameter, then use the corresponding operand as the + // result of the computation. 
TF_RETURN_IF_ERROR( - computation_->ReplaceInstruction(map, placed_instruction)); - } else { + map->ReplaceAllUsesWith(map->operands()[root.parameter_number()])); + TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map)); + } else if (root.opcode() == HloOpcode::kConstant) { + // If the input is a constant then the shape of the constant could be + // different than the map shape. Hence, a broadcast is needed, else the + // cloned operand with new shape and operands work. + // // The constant is in an embedded computation and needs to be recreated // as part of the computation that the broadcast is inserted into. HloInstruction* constant = computation_->AddInstruction(root.Clone()); @@ -98,6 +94,15 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) { HloInstruction::CreateBroadcast(map->shape(), constant, {})); TF_RETURN_IF_ERROR( computation_->ReplaceInstruction(map, placed_instruction)); + } else { + std::vector params; + for (int64 o = 0; o < root.operands().size(); o++) { + params.push_back(map->operands()[root.operand(o)->parameter_number()]); + } + HloInstruction* placed_instruction = computation_->AddInstruction( + root.CloneWithNewOperands(map->shape(), params)); + TF_RETURN_IF_ERROR( + computation_->ReplaceInstruction(map, placed_instruction)); } changed_ = true; return Status::OK(); @@ -106,8 +111,8 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) { return Status::OK(); } -StatusOr Inliner::Run(HloModule* module) { - InlinerVisitor visitor(/*computation=*/nullptr); +StatusOr MapInliner::Run(HloModule* module) { + MapInlinerVisitor visitor(/*computation=*/nullptr); bool changed = false; for (HloComputation* computation : module->computations()) { TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation)); diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/map_inliner.h similarity index 59% rename from tensorflow/compiler/xla/service/inliner.h rename to 
tensorflow/compiler/xla/service/map_inliner.h index e20af08fb7329c3646903761ee081e421daa5712..b67911811846e2250068921ef252b7df596d4016 100644 --- a/tensorflow/compiler/xla/service/inliner.h +++ b/tensorflow/compiler/xla/service/map_inliner.h @@ -13,27 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { -// A pass which performs inlining. Which can result, for example, in functions -// that were previously being mapped by Map instead directly applied to the -// forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)). -class Inliner : public HloModulePass { +// A pass which performs map inlining. This replaces kMap instructions with +// their equivalent sequence of array operations. For example: +// map({X, Y}, add) -> add(X, Y). +class MapInliner : public HloModulePass { public: - ~Inliner() override = default; - absl::string_view name() const override { return "inline"; } + ~MapInliner() override = default; + absl::string_view name() const override { return "map-inline"; } - // Run inlining on the given computation. Returns whether the computation was - // changed. + // Run map inlining on the given computation. Returns whether the computation + // was changed.
StatusOr Run(HloModule* module) override; }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/map_inliner_test.cc similarity index 78% rename from tensorflow/compiler/xla/service/inliner_test.cc rename to tensorflow/compiler/xla/service/map_inliner_test.cc index 7e967f035c1054e22d10790188a5a232ca8e751a..84059dd0f71ee8fc0a25703cbab2268d7dc149a8 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/map_inliner_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/inliner.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include #include @@ -35,10 +35,10 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using InlinerTest = HloVerifiedTestBase; +using MapInlinerTest = HloVerifiedTestBase; // Test that `map` with `max` is transformed to `max` -TEST_F(InlinerTest, MapMax) { +TEST_F(MapInlinerTest, MapMax) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); auto max_builder = HloComputation::Builder(TestName()); @@ -63,7 +63,7 @@ TEST_F(InlinerTest, MapMax) { hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); - Inliner inliner; + MapInliner inliner; EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Maximum(lhs, rhs)); @@ -75,7 +75,7 @@ TEST_F(InlinerTest, MapMax) { } // Test that `constant` function is changed to `broadcast`. 
-TEST_F(InlinerTest, MapConstant) { +TEST_F(MapInlinerTest, MapConstant) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); auto const2_builder = HloComputation::Builder(TestName()); @@ -97,7 +97,7 @@ TEST_F(InlinerTest, MapConstant) { hlo_module->AddEmbeddedComputation(std::move(const2_f32)); hlo_module->AddEntryComputation(std::move(computation)); HloInstruction* root = hlo_module->entry_computation()->root_instruction(); - Inliner inliner; + MapInliner inliner; EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); root = hlo_module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); @@ -108,7 +108,7 @@ TEST_F(InlinerTest, MapConstant) { EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } -TEST_F(InlinerTest, MapSubtractOppositeOrder) { +TEST_F(MapInlinerTest, MapSubtractOppositeOrder) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); // Note that the parameter ordinals are in the opposite order to their @@ -135,7 +135,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); - Inliner inliner; + MapInliner inliner; EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Subtract(rhs, lhs)); @@ -146,6 +146,36 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } +TEST_F(MapInlinerTest, MapParameter) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + auto param_builder = HloComputation::Builder(TestName()); + param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0")); + param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1")); + auto param_f32 = param_builder.Build(); + + auto builder = HloComputation::Builder("MapParamFunction"); + auto lhs = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + auto rhs = builder.AddInstruction( + 
HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); + builder.AddInstruction( + HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get())); + + auto computation = builder.Build(); + auto hlo_module = CreateNewVerifiedModule(); + hlo_module->AddEmbeddedComputation(std::move(param_f32)); + hlo_module->AddEntryComputation(std::move(computation)); + + MapInliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs); + + // Verify execution on CPU. + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); + auto expected = LiteralUtil::CreateR0(4); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index b9ec31c4977be0c31dfff01a0c495902191d7d5b..2ca527bc4cb8f66a085c1e6a7cbb8ddaedbfc07e 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -15,10 +15,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/multi_output_fusion.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -50,7 +50,7 @@ StatusOr MultiOutputFusion::Run(HloModule* module) { all_fusion_candidates_.push_back(instruction); std::vector candidates; - tensorflow::gtl::FlatSet candidates_set; + absl::flat_hash_set candidates_set; VLOG(10) << "Looking at instruction: " << instruction->name(); for (auto operand : instruction->operands()) { // Filter out the non-interesting instructions -- they @@ -172,7 +172,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { // Update the fusible list for fusion. Variable new_fusibles keeps // track of the new or changed entries. std::vector> new_fusibles; - tensorflow::gtl::FlatSet in_list; + absl::flat_hash_set in_list; auto it = fusion_node.fusibles.begin(); while (it != fusion_node.fusibles.end()) { HloInstruction* instr = it->first; diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 0344626b26b2cd1d659657c51636266706d17afb..9508ab2ed1d38ec40983d8892ec8875b848fb21b 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -126,7 +127,7 @@ class MultiOutputFusion : public HloModulePass { std::vector candidates_; // A map that maps an instruction to the index_. 
- tensorflow::gtl::FlatMap candidates_index_; + absl::flat_hash_map candidates_index_; // The reachability map of current computation. std::unique_ptr reachability_; diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index 6dd89c240f81c9f0ccac66e50c7f244bfd5429f1..8909d0f4fea801e43ab06a75e8933d24a74146bc 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -18,10 +18,10 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -69,7 +69,7 @@ class NameUniquer { int64 next_ = 0; // Set of all the identifiers which has been used. - tensorflow::gtl::FlatSet used_; + absl::flat_hash_set used_; }; // The string to use to separate the prefix of the name from the uniquing @@ -78,7 +78,7 @@ class NameUniquer { // Map from name prefix to the generator data structure which tracks used // identifiers and generates new ones. - tensorflow::gtl::FlatMap generated_names_; + absl::flat_hash_map generated_names_; TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer); }; diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index 4bb22428f3d66f27d268ac4490c6e2613966cbed..0b4e82e8d606cf2cacfab42d07c2201939d5e10b 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index b27a92f2a0761a2bccd97eb2c0467ead27565c37..75465359f8f37e56369c0976ba7434e3c3f202cc 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -207,7 +207,7 @@ Status Service::ValidateResultShape(const Shape& client_shape, StatusOr>> Service::ResolveAndValidateArguments( absl::Span arguments, - absl::Span stream_executors) { + absl::Span stream_executors) const { CHECK_EQ(options_.number_of_replicas(), stream_executors.size()); std::vector> replicated_arguments; replicated_arguments.resize(options_.number_of_replicas()); @@ -341,19 +341,19 @@ StatusOr>> Service::BuildExecutables( } CHECK_EQ(module_protos.size(), module_configs.size()); - std::vector> modules; + auto module_group = + absl::make_unique(module_protos[0]->name()); for (int64 i = 0; i < module_protos.size(); ++i) { const HloModuleProto* proto = module_protos[i]; const HloModuleConfig& config = *module_configs[i]; - TF_ASSIGN_OR_RETURN(auto module, - HloModule::CreateFromProto(*proto, config)); - modules.push_back(std::move(module)); + TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config)); + module_group->push_back(std::move(module)); } TF_ASSIGN_OR_RETURN( std::vector> executables, - backend->compiler()->Compile(std::move(modules), std::move(executors), - device_allocator)); + backend->compiler()->Compile(std::move(module_group), + std::move(executors), device_allocator)); for (size_t i = 0; i < module_protos.size(); ++i) { if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { @@ -590,7 +590,7 @@ StatusOr> Service::GetExecutors( StatusOr>> Service::GetArguments( const 
ExecutionOptions& execution_options, - absl::Span arguments) { + absl::Span arguments) const { // Resolve the allocations for the arguments of the computation, and create // a vector of device memory offsets for the arguments from the allocations. // In the case of partitioned computations, assume all arguments go on the @@ -634,7 +634,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, arg->requests(i).execution_options(); const ExecuteGraphRequest& request = arg->requests(i); TF_RET_CHECK(request.has_computation()) << "computations may not be empty"; - TF_RET_CHECK(request.computation().has_program_shape()) + TF_RET_CHECK(request.computation().has_host_program_shape()) << "programe shape may not be empty"; // Get the executors. @@ -651,7 +651,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, // replica 0. TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(request.computation().program_shape(), + CreateModuleConfig(request.computation().host_program_shape(), replicated_arguments.front(), request.execution_options())); VLOG(3) @@ -810,7 +810,7 @@ StatusOr> Service::BuildExecutable( } TF_ASSIGN_OR_RETURN(std::unique_ptr module, - HloModule::CreateFromProto(module_proto, *module_config)); + CreateModuleFromProto(module_proto, *module_config)); TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module)); @@ -836,7 +836,7 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } - if (!arg->computation().has_program_shape()) { + if (!arg->computation().has_host_program_shape()) { return InvalidArgument("programe shape may not be empty"); } @@ -851,10 +851,11 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, std::vector> replicated_arguments, ResolveAndValidateArguments(arg->arguments(), replicas)); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - 
CreateModuleConfig(arg->computation().program_shape(), - replicated_arguments.front(), - arg->execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(arg->computation().host_program_shape(), + replicated_arguments.front(), + arg->execution_options())); TF_ASSIGN_OR_RETURN( std::unique_ptr executable, @@ -1063,15 +1064,15 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } - if (!arg->computation().has_program_shape()) { + if (!arg->computation().has_host_program_shape()) { return InvalidArgument("program shape may not be empty"); } - if (arg->computation().program_shape().parameters_size() != 0) { + if (arg->computation().host_program_shape().parameters_size() != 0) { return InvalidArgument( "constant computation may not depend on any parameters."); } - ProgramShape program_shape = arg->computation().program_shape(); + ProgramShape program_shape = arg->computation().host_program_shape(); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); if (arg->has_output_layout()) { TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( @@ -1081,7 +1082,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, HloModuleConfig config(program_shape); TF_ASSIGN_OR_RETURN(std::unique_ptr module, - HloModule::CreateFromProto(arg->computation(), config)); + CreateModuleFromProto(arg->computation(), config)); HloEvaluator evaluator; TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate( @@ -1111,14 +1112,14 @@ Status Service::GetComputationGraphStats( if (!arg->has_computation()) { return InvalidArgument("Computations may not be empty."); } - if (!arg->computation().has_program_shape()) { + if (!arg->computation().has_host_program_shape()) { return InvalidArgument("Program shape may not be empty."); } - HloModuleConfig config(arg->computation().program_shape()); + HloModuleConfig 
config(arg->computation().host_program_shape()); config.set_debug_options(arg->debug_options()); TF_ASSIGN_OR_RETURN(std::unique_ptr module, - HloModule::CreateFromProto(arg->computation(), config)); + CreateModuleFromProto(arg->computation(), config)); hlo_graph_dumper::MaybeDumpHloModule(*module, "computation statistics subject"); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 1f62fad4c8079eba7013b3f647fe19bbc031fc77..8cf1a7b9f01fbb3572c6849c8b18e14174ced89f 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -191,7 +191,7 @@ class Service : public ServiceInterface { // Prepare the arguments for executing parallel. StatusOr>> GetArguments( const ExecutionOptions& execution_options, - absl::Span arguments); + absl::Span arguments) const; protected: friend class LocalExecutable; @@ -208,7 +208,7 @@ class Service : public ServiceInterface { StatusOr>> ResolveAndValidateArguments( absl::Span arguments, - absl::Span stream_executors); + absl::Span stream_executors) const; // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 7194b2cafd348c144a2ee83027cf78642bfaf75f..25afc23e5b41468ad5dd1abed076e399cf20f350 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -33,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -577,7 +577,7 @@ Status ValidateDotDimensionNumbers( // Check that dimension numbers are unique. auto dims_unique = [](absl::Span contracting_dims, absl::Span batch_dims) -> bool { - tensorflow::gtl::FlatSet dim_set; + absl::flat_hash_set dim_set; auto is_unique = [&dim_set](int64 i) -> bool { return dim_set.insert(i).second; }; @@ -919,6 +919,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, switch (opcode) { case HloOpcode::kMaximum: case HloOpcode::kMinimum: + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, + broadcast_dimensions); + case HloOpcode::kSubtract: case HloOpcode::kAdd: case HloOpcode::kAtan2: @@ -929,6 +932,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: + if (lhs.element_type() == PRED || rhs.element_type() == PRED) { + return InvalidArgument( + "Expected element type in shape to be arithmetic type for " + "operation %s; got PRED.", + HloOpcodeString(opcode)); + } return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); @@ -1029,17 +1038,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case HloOpcode::kSort: { if (operand_shapes.size() == 1) { return *operand_shapes[0]; - } else if (operand_shapes.size() == 2) { - if (!ShapeUtil::SameDimensions(*operand_shapes[0], - *operand_shapes[1])) { - return InvalidArgument( - "Sort keys and values dimensions must match. 
" - "Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(*operand_shapes[0]), - ShapeUtil::HumanString(*operand_shapes[1])); + } else { + for (int64 operand = 1; operand < operand_shapes.size(); ++operand) { + if (!ShapeUtil::SameDimensions(*operand_shapes[0], + *operand_shapes[operand])) { + return InvalidArgument( + "Sort keys and values dimensions must match. " + "Keys shape is: %s\n, Values shape (operand index %lld) is: %s", + ShapeUtil::HumanString(*operand_shapes[0]), operand, + ShapeUtil::HumanString(*operand_shapes[operand])); + } + } + std::vector operand_shape_values; + for (const Shape* operand_shape : operand_shapes) { + operand_shape_values.push_back(*operand_shape); } - return ShapeUtil::MakeTupleShape( - {*operand_shapes[0], *operand_shapes[1]}); + return ShapeUtil::MakeTupleShape(operand_shape_values); } return InvalidArgument("Unexpected number of operands for sort"); } @@ -2380,7 +2394,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, !std::is_permutation(dimensions.begin(), dimensions.end(), indices.begin())) { return InvalidArgument( - "Transpose dimensions not a permutation of the operand dimensions."); + "Transpose dimensions [%s] are not a permutation of the operand " + "dimensions (operand shape is %s).", + StrJoin(dimensions, ","), ShapeUtil::HumanString(operand)); } // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. 
However, diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 864ed43118cd066f6ce14cd808b873f137b8414a..7b65e8c1c9d2bc730c6c8550e9265b69fdde71cf 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1618,13 +1618,37 @@ TEST_F(ShapeInferenceTest, BadSort) { auto values = ShapeUtil::MakeShape(F32, {5}); StatusOr statusor = ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values}); - ASSERT_FALSE(statusor.ok()); + EXPECT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("dimensions must match")) + << statusor.status(); +} +TEST_F(ShapeInferenceTest, BadSortValuesMismatch) { + auto keys = ShapeUtil::MakeShape(F32, {4}); + auto values_good = ShapeUtil::MakeShape(F32, {4}); + auto values_bad = ShapeUtil::MakeShape(F32, {5}); + StatusOr statusor = ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, {&keys, &values_good, &values_bad}); + EXPECT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("dimensions must match")) << statusor.status(); } +TEST_F(ShapeInferenceTest, SortManyValues) { + auto keys = ShapeUtil::MakeShape(F32, {4}); + auto values_s32 = ShapeUtil::MakeShape(S32, {4}); + auto values_u32 = ShapeUtil::MakeShape(U32, {4}); + StatusOr statusor = ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, {&keys, &values_s32, &values_u32}); + EXPECT_IS_OK(statusor); + Shape inferred_shape = statusor.ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Compatible( + inferred_shape, + ShapeUtil::MakeTupleShape({keys, values_s32, values_u32}))); +} + class ScatterGatherShapeInferenceTest : public ShapeInferenceTest { protected: const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 
921a984589bb4fb64058a2a56adfe84fe14af69b..56952e3adae59656605a12fd499162504a2a3379 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -147,7 +147,7 @@ void ScopedShapedBuffer::Deallocate() { // Deallocate all non-null buffers. A buffer may appear in more than one spot // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. - tensorflow::gtl::FlatSet deallocated_ptrs; + absl::flat_hash_set deallocated_ptrs; for (auto& pair : buffers_) { se::DeviceMemoryBase& memory_base = pair.second; if (!memory_base.is_null() && diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 6fed7c76d04ad5d8236fecd07aa27f1eda221ea7..96f3055c98e0611dfe25517cb490014a6d1f7c76 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -148,7 +148,7 @@ TuplePointsToAnalysis::Run(const HloModule* module) { Status TuplePointsToAnalysis::Analyze() { per_instruction_.clear(); - per_instruction_.resize(module_->NumUniqueInstructionIds()); + per_instruction_.reserve(module_->instruction_count()); logical_buffer_aliases_.clear(); logical_buffer_aliases_.resize( @@ -280,16 +280,6 @@ Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) { return Status::OK(); } -Status 
TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { - // A kSlice instruction aliases its operand if the backend lowers it to an - // in-place implementation. - if (slice->IsInPlaceSlice()) { - CreateCopiedPointsToSet(slice, slice->operand(0)); - return Status::OK(); - } - return DefaultAction(slice); -} - Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its // output. The other indices ({} and {1}) define their own buffers. @@ -455,15 +445,10 @@ bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex( Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const { if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) { - // kSlice ops that are lowered to an in-place version are expected to not - // define their output buffer. - if (buffer.instruction()->opcode() != HloOpcode::kSlice || - !buffer.instruction()->IsInPlaceSlice()) { - return FailedPrecondition( - "LogicalBuffer %s is ill-defined: instruction %s does not define a " - "buffer at that index", - buffer.ToString(), buffer.instruction()->name()); - } + return FailedPrecondition( + "LogicalBuffer %s is ill-defined: instruction %s does not define a " + "buffer at that index", + buffer.ToString(), buffer.instruction()->name()); } if (buffer.id() < 0 || @@ -771,6 +756,7 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( } } if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kScatter || user->opcode() == HloOpcode::kWhile) { // We eliminated other users in BufferLiveness::live_range_strictly_before, // so here we just need to check that the use is at operand index 0. 
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index a9e8a51e0923362162c6b8a2e97fc334e56d4329..bcfcb388f95b0bedb35a8c399e804034816867b3 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -36,8 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/compactptrset.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -249,7 +248,6 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleDomain(HloInstruction* domain) override; - Status HandleSlice(HloInstruction* slice) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; @@ -318,14 +316,23 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { const PerInstruction* PerInst(const HloInstruction* inst) const { int id = inst->unique_id(); DCHECK_GE(id, 0); - DCHECK_LT(id, per_instruction_.size()); - return &per_instruction_[id]; + auto iter = per_instruction_.find(id); + if (iter == per_instruction_.end()) { + LOG(FATAL) << "Expected per-instruction information to already exist"; + } else { + return iter->second.get(); + } } PerInstruction* PerInst(const HloInstruction* inst) { 
int id = inst->unique_id(); DCHECK_GE(id, 0); - DCHECK_LT(id, per_instruction_.size()); - return &per_instruction_[id]; + auto iter = per_instruction_.find(id); + if (iter == per_instruction_.end()) { + return per_instruction_.emplace(id, absl::make_unique()) + .first->second.get(); + } else { + return iter->second.get(); + } } std::vector> GetAllUsesOfInstructionAtIndex( @@ -342,7 +349,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { const std::unique_ptr logical_buffer_analysis_; // A map from instruction->unique_id() to - std::vector per_instruction_; + absl::flat_hash_map> per_instruction_; // A map from LogicalBuffer->id() to alias information about that logical // buffer diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index e9a07b14ed685fa4388aca583395370a60176cca..d9ebebf74ed846aa05326a4df72019ef3e71ad88 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -1010,6 +1010,44 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { + const char* hlo_text = R"( + HloModule TensorFlowScatterV1 + + update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) + } + + ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); + computation_ = module_->entry_computation(); + RunAnalysis(); + + HloInstruction* operand_param = 
computation_->parameter_instruction(0); + HloInstruction* indices_param = computation_->parameter_instruction(1); + HloInstruction* updates_param = computation_->parameter_instruction(2); + HloInstruction* scatter = computation_->root_instruction(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser( + operand_param, {}, scatter, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser( + indices_param, {}, scatter, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser( + updates_param, {}, scatter, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto builder = HloComputation::Builder(TestName()); @@ -1035,7 +1073,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { auto values = builder.AddInstruction( HloInstruction::CreateParameter(1, values_shape, "values")); auto sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, + {values})); BuildModuleAndRunAnalysis(builder.Build()); diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index 56145822be70f391ac3eaab5fc17db4a80e1b9cc..067cfcc17d65860a249de4d9e31703df12091d3a 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index e8fe33e62659ae0fffff1ad46e8ba77f715b76b2..9795b2830b6d9add82b89ac76b5438ddc3d2bfe8 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -15,18 +15,18 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { +using absl::flat_hash_map; +using absl::flat_hash_set; using absl::InlinedVector; -using tensorflow::gtl::FlatMap; -using tensorflow::gtl::FlatSet; // Copies `to_hoist` to the computation containing `while_instr`, hoisting its // operands as needed. All of its transitive operands are expected to be either @@ -34,8 +34,8 @@ using tensorflow::gtl::FlatSet; // function hoists the operands in `unhoisted_invariant_instructions` and moves // them into `hoisted_instructions`. 
static void CreateLoopInvariantCopy( - FlatMap* hoisted_instructions, - FlatSet* unhoisted_invariant_instructions, + flat_hash_map* hoisted_instructions, + flat_hash_set* unhoisted_invariant_instructions, HloInstruction* while_instr, HloInstruction* to_hoist) { HloComputation* parent_of_while = while_instr->parent(); HloComputation* while_body = while_instr->while_body(); @@ -147,13 +147,13 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( // Maps instructions in the while body to instructions hoisted outside the // while that compute the same value. - FlatMap hoisted_instructions; + flat_hash_map hoisted_instructions; // Contains instructions that can be legally hoisted, but were deemed to be // unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we // hoist an instruction in this set, we move it from // unhoisted_invariant_instructions to hoisted_instructions. - FlatSet unhoisted_invariant_instructions; + flat_hash_set unhoisted_invariant_instructions; // Invariant GTE's axiomatically satisfy the constraints for // unhoisted_invariant_instructions -- they can be legally hoisted, but there diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 9a74f22395099fe4f14cbc9af49814d35203df01..630d71e5ca25e9d282ce6283284a32d6f725a193 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -14,12 +14,13 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -114,7 +115,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { return false; } - tensorflow::gtl::FlatSet used_tuple_indices; + absl::flat_hash_set used_tuple_indices; for (HloComputation* comp : {while_body, while_cond}) { // The HLO verifier ensures that while_input's shape matches while_init's // shape, which we verified above is a tuple. @@ -181,7 +182,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { used_tuple_indices.end()); std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end()); - tensorflow::gtl::FlatMap old_to_new_tuple_idx; + absl::flat_hash_map old_to_new_tuple_idx; for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { int64 old_idx = new_to_old_tuple_idx[new_idx]; old_to_new_tuple_idx[old_idx] = new_idx; @@ -405,7 +406,7 @@ static StatusOr TryPropagateConstant(HloInstruction* while_op) { // build a map from the tuple element index to the constant value. Limit this // to scalar constant values because propagating array constants can regress // performance by forcing us to copy constants. 
- tensorflow::gtl::FlatMap index_to_constant; + absl::flat_hash_map index_to_constant; for (int i = 0; i < root_operands.size(); i++) { HloInstruction* instr = root_operands[i]; if (instr->opcode() == HloOpcode::kGetTupleElement && diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 020c167ee953bbb3508bae94107de60f386602c0..f55508f8e6d2fd025344bc8289d2403608e3ee25 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -461,8 +461,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; } -/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) { - return shape.element_type() == F32 && Rank(shape) == 0; +/* static */ bool ShapeUtil::IsScalarWithElementType( + const Shape& shape, PrimitiveType element_type) { + return IsScalar(shape) && shape.element_type() == element_type; } namespace { @@ -596,7 +597,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { // we convert in to the RE2-consumable type and then consume the corresponding // amount from our string_view type. 
static LazyRE2 shape_pattern = { - "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?"}; + "^(\\w*\\d*)\\[([\\d,\\s]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,\\s]+)})" + "?"}; tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); if (RE2::Consume(&s_consumable, *shape_pattern, &element_type_string, &dimensions_string, &format_string, &layout_string)) { @@ -831,7 +833,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { - if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { + if (shape.element_type() == PRIMITIVE_TYPE_INVALID || + !PrimitiveType_IsValid(shape.element_type())) { return InvalidArgument("shape has invalid element type: %s", shape.ShortDebugString()); } @@ -868,11 +871,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { return Status::OK(); } - if (Rank(shape) != shape.dimensions_size()) { - return InvalidArgument( - "shape's rank is mismatched with dimension count; rank=%d " - "dimensions_size=%d", - Rank(shape), shape.dimensions_size()); + if (LayoutUtil::IsSparseArray(shape) && Rank(shape) == 0) { + return InvalidArgument("sparse arrays must have rank > 0"); } for (int64 i = 0; i < Rank(shape); ++i) { int64 dimension = shape.dimensions(i); @@ -931,7 +931,12 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { return dense_shape_size; } - for (int64 dim : shape.dimensions()) { + bool is_padded = + LayoutUtil::IsDenseArray(shape) && LayoutUtil::IsPadded(shape); + absl::Span shape_max_dimensions = + is_padded ? 
LayoutUtil::PaddedDimensions(shape) + : AsInt64Slice(shape.dimensions()); + for (int64 dim : shape_max_dimensions) { dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim); if (dense_shape_size < 0) { return dense_shape_size; @@ -953,11 +958,10 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout( const Shape& shape) { - if (LayoutUtil::HasLayout(shape)) { - // Since a layout is present, upgrade to the full set of invariant checks. - return ValidateShape(shape); - } - return ValidateShapeWithOptionalLayoutInternal(shape); + TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape)); + + return LayoutUtil::ValidateLayoutInShape(shape, + /*allow_missing_layouts=*/true); } /* static */ Status ShapeUtil::ValidateShape(const Shape& shape) { @@ -1647,7 +1651,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, } std::ostream& operator<<(std::ostream& out, const Shape& shape) { - out << ShapeUtil::HumanString(shape); + out << ShapeUtil::HumanStringWithLayout(shape); return out; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index d8bb27beae64bb665c79c2cd7134f613495529cc..51cedce7f0e13e65dfd0e250689e0ecd30f971dc 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -72,7 +72,7 @@ class ShapeIndex { void push_back(int64 value) { indices_.push_back(value); } void pop_back() { indices_.pop_back(); } - // push_front is O(n^2), but shapes don't usually have a ton of dimensions. + // push_front is O(n), but shapes don't usually have a ton of dimensions. 
void push_front(int64 value) { indices_.insert(indices_.begin(), value); } using container_type = absl::InlinedVector; @@ -312,7 +312,10 @@ class ShapeUtil { static bool IsEffectiveScalar(const Shape& shape) { return IsArray(shape) && TrueRank(shape) == 0; } - static bool IsScalarF32(const Shape& shape); + + // Returns whether "shape" is a scalar (array) with the given element_type. + static bool IsScalarWithElementType(const Shape& shape, + PrimitiveType element_type); // Extracts the size of the shape's dimension at dimension number // GetDimensionNumber(dimension_number). diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index f474ecb18c75327edec449433c36a91d8ac7de83..8a0ae330420531b833ed670118e6b6b1056bd358 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -422,6 +422,7 @@ xla_test( "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) @@ -2145,11 +2146,11 @@ xla_test( ":test_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_set", ], ) diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc index 022641394f113ef28e7c53058385d77572822213..fbebe0408730f2fb37aa57a0f19291bbaa3826f9 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.cc +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -32,11 +32,10 @@ StatusOr> CodegenTestBase::CompileToAotCompilationResult( std::unique_ptr hlo_module, const AotCompilationOptions& options) { - std::vector> hlo_modules; - 
hlo_modules.push_back(std::move(hlo_module)); + auto module_group = absl::make_unique(std::move(hlo_module)); TF_ASSIGN_OR_RETURN( std::vector> results, - backend().compiler()->CompileAheadOfTime(std::move(hlo_modules), + backend().compiler()->CompileAheadOfTime(std::move(module_group), options)); return std::move(results.front()); } diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 070b092d18930027e215cb43ff917e36cac99f12..3aebf784664dac14ba2ea45c5a229b7b2e4fc39d 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -91,7 +91,14 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { XlaBuilder builder(TestName()); auto lhs = ConstantR4FromArray4D(&builder, *alhs); auto rhs = ConstantR4FromArray4D(&builder, *arhs); - Conv(lhs, rhs, {1, 1}, Padding::kValid); + PrecisionConfig precision; + // The left hand side of the convolution is numbers between 0 and 2304 which + // requires at least 11 mantissa bits and the DEFAULT precision config is + // allowed to round to bfloat16 which only has 7 mantissa bits. + precision.add_operand_precision(PrecisionConfig::HIGHEST); + precision.add_operand_precision(PrecisionConfig::DEFAULT); + Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1, + &precision); ComputeAndCompare(&builder, {}, error_spec_); } @@ -876,7 +883,7 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { // (We run this test on all platforms, because, what the heck.) 
XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( - "cudnn-convolution-algorithm-picker"); + "cudnn-conv-algorithm-picker"); XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index a693fa35954bcb2d95074c94d0aa3eabc1d5fd62..001490c6a8c568656437465054ee4db40d0d8dee 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -105,8 +105,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, - DISABLED_ON_GPU(CustomCall_UsedInOtherComputations)) { +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { auto module = CreateNewModule(); auto b = HloComputation::Builder(TestName()); @@ -130,6 +129,53 @@ XLA_TEST_F(CustomCallTest, Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); } +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { + auto module = CreateNewModule(); + auto b = HloComputation::Builder(TestName()); + + auto input = + b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); + b.AddInstruction( + HloInstruction::CreateCustomCall(r2f32_, {input}, "Add1ToValues")); + + module->AddEntryComputation(b.Build()); + ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); + ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1})); + + Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); + + // Note, the expected result is transposed! This is because the input and + // output layouts of the custom call differ and the called function just + // blindly adds one to each element. 
+ Literal result = ExecuteAndTransfer(std::move(module), {&argument}); + LiteralTestUtil::ExpectR2Equal({{2.f, 4.f}, {3.f, 5.f}}, result); +} + +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { + // The argument and result of the computation are set to different layouts, + // but the custom call is layout constrained to a fixed operand and result + // layout, so the correct result should be produced. + auto module = CreateNewModule(); + auto b = HloComputation::Builder(TestName()); + + auto input = + b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); + + const Shape& r2f32_dim0_major = + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); + b.AddInstruction(HloInstruction::CreateCustomCall( + r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major})); + + module->AddEntryComputation(b.Build()); + ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); + ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1})); + + Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); + + Literal result = ExecuteAndTransfer(std::move(module), {&argument}); + LiteralTestUtil::ExpectR2Equal({{2.f, 3.f}, {4.f, 5.f}}, result); +} + class CustomCallClientAPITest : public ClientLibraryTestBase {}; // When using the client API, CustomCall targets can't begin with '$' -- these diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 0171f515839d556827f0723772214d175939d386..6c0847a875798870b4362a99ac2ab65d99f9f3e6 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -394,6 +394,10 @@ class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest { ParametricDotTestWithoutLayoutAssignment() { execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "layout-assignment"); + // Disable algebraic simplification because the pass may replace a dot + // 
instruction with a layout-changing multiplication instruction. + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "algsimp"); } }; @@ -404,31 +408,18 @@ std::vector CreateNoLayoutAssignmentDotTestParameters() { for (bool lhs_row_major : {true, false}) { for (bool rhs_row_major : {true, false}) { for (bool has_addend : {true, false}) { + // The addend needs to be row major to match the result of the dot. params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, /*dot_lhs_row_major=*/lhs_row_major, /*dot_rhs_row_major=*/rhs_row_major, /*has_addend=*/has_addend, /*addend_row_major=*/true}); - if (has_addend) { - params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, - /*dot_lhs_row_major=*/lhs_row_major, - /*dot_rhs_row_major=*/rhs_row_major, - /*has_addend=*/has_addend, - /*addend_row_major=*/false}); - } if (n != 1) { params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, /*dot_lhs_row_major=*/lhs_row_major, /*dot_rhs_row_major=*/rhs_row_major, /*has_addend=*/has_addend, /*addend_row_major=*/true}); - if (has_addend) { - params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, - /*dot_lhs_row_major=*/lhs_row_major, - /*dot_rhs_row_major=*/rhs_row_major, - /*has_addend=*/has_addend, - /*addend_row_major=*/false}); - } } } } diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 9c94acb437e9fc948a4255f7112e2e7a40cfa5fb..4d4b676a538947c8dd92a7e34db72e45766cae2c 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -764,8 +764,10 @@ XLA_TEST_F(FusionTest, Clamp2D) { TestElementwise2D(HloOpcode::kClamp); } -// TODO(b/73903144): Enable on interpreter once interpreter supports bitcast. -XLA_TEST_F(FusionTest, DISABLED_ON_INTERPRETER(FusionWithLayout)) { +// TODO(b/117156505): Remove this test when the bug is fixed and the CPU backend +// should not generate layout changing elementwise operations. 
+#ifdef XLA_TEST_BACKEND_CPU +XLA_TEST_F(FusionTest, LayoutChangingElementWiseOp) { const string hlo_text = R"( HloModule Cluster @@ -794,6 +796,7 @@ ENTRY main { LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), result)); } +#endif class FusionClientLibraryTest : public ClientLibraryTestBase {}; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index bdd4fd7e3d0f585d81e94a3326e6d24bb5c42f39..7ab2ecda58666acd7e9b8587d200a902b75822f3 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -86,19 +86,25 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace HloTestBase::HloTestBase(bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function + instruction_can_change_layout_func) : HloTestBase(GetTestPlatform(), GetReferencePlatform(), verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} + allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func) {} HloTestBase::HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function + instruction_can_change_layout_func) : test_runner_(test_platform), reference_runner_(reference_platform) { hlo_verifier_ = absl::make_unique( /*layout_sensitive=*/verifier_layout_sensitive, - /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier); + /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func); } std::unique_ptr HloTestBase::CreateNewModule(const string& name) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 0ae4bdc104d656946d45008adec9ea3960984545..217428befa474448cf2dcbae2eb6cb5b0e61d44c 
100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -88,14 +88,18 @@ class HloTestBase : public ::testing::Test { // interpreter is the only supported backend, it will be both the test backend // and the reference backend. HloTestBase(bool verifier_layout_sensitive = false, - bool allow_mixed_precision_in_hlo_verifier = true); + bool allow_mixed_precision_in_hlo_verifier = true, + std::function + instruction_can_change_layout_func = {}); // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, bool verifier_layout_sensitive = false, - bool allow_mixed_precision_in_hlo_verifier = true); + bool allow_mixed_precision_in_hlo_verifier = true, + std::function + instruction_can_change_layout_func = {}); ~HloTestBase() override {} diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 8d658695576035cdc34a213847460dd80de5f67e..c622b295094e53e63d0ed692d428bc97724c787c 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -93,15 +93,16 @@ class LLVMCompilerTest : public ::testing::Test { std::unique_ptr hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); - std::vector> modules; - modules.push_back(hlo_module->Clone()); - modules.push_back(std::move(hlo_module)); + auto module_group = absl::make_unique("test_module_group"); + module_group->push_back(hlo_module->Clone()); + module_group->push_back(std::move(hlo_module)); std::vector> executors; executors.push_back({backend_->default_stream_executor()}); executors.push_back({backend_->default_stream_executor()}); - EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors), + 
EXPECT_IS_OK(compiler->Compile(std::move(module_group), + std::move(executors), /*device_allocator=*/nullptr)); } @@ -150,12 +151,12 @@ TEST_F(GpuCompilerTest, HooksTest) { TestCompilerHooks(&compiler); } -TEST_F(CpuCompilerTest, MultiModuleCompilation) { +TEST_F(CpuCompilerTest, CpuMultiModuleCompilation) { cpu::CpuCompiler compiler; TestMultiModuleCompilation(&compiler); } -TEST_F(GpuCompilerTest, MultModuleCompilation) { +TEST_F(GpuCompilerTest, NVPTXMultiModuleCompilation) { gpu::NVPTXCompiler compiler; TestMultiModuleCompilation(&compiler); } diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 26e2bfde5cdc19657640f24f31bc008d09ad7106..193e66969259f3a8dc18f959c5e72baee11dce24 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -283,7 +283,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) { XlaBuilder builder(TestName()); - Literal a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001, 1.00001}); std::unique_ptr a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal.shape(), "a"); @@ -301,7 +301,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS, 5, 10, [](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; }); - ComputeAndCompareR1(&builder, {-1.00001f}, {a_data.get()}); + ComputeAndCompareR1(&builder, {-1.00001f, -1.00001f}, {a_data.get()}); } // The interpreter has no fusion pass, so skip this test. 
@@ -309,7 +309,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) { XlaBuilder builder(TestName()); - Literal a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001, 1.00001}); std::unique_ptr a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal.shape(), "a"); @@ -325,7 +325,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS, 5, 10, [](const HloOpcode opcode) { return opcode == HloOpcode::kFusion; }); - ComputeAndCompareR1(&builder, {-1.0f}, {a_data.get()}); + ComputeAndCompareR1(&builder, {-1.0f, -1.0f}, {a_data.get()}); } // The interpreter has no fusion pass, so skip this test. @@ -358,7 +358,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) { XlaBuilder builder(TestName()); - Literal a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001, 1.00001}); std::unique_ptr a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal.shape(), "a"); @@ -375,7 +375,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT, 5, 10, [](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; }); - ComputeAndCompareR1(&builder, {-1.0f}, {a_data.get()}); + ComputeAndCompareR1(&builder, {-1.0f, -1.0f}, {a_data.get()}); } } // namespace diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index c25ccafaf83cf1b29095a77eefa357d9af08dc60..22fe4a2670e2e0e1fedc45036a1ceec19f44e42e 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -638,6 +638,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*computation=*/computation, 
/*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, /*padding=*/padding); CHECK(reducer == kAdd || reducer == kMax); @@ -1158,7 +1160,10 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/padding); + /*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, + /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? +[](float a, float b) { return a + b; } @@ -1369,7 +1374,10 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/padding); + /*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, + /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? +[](float a, float b) { return a + b; } diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index b21dd56045e1dc11847e213852dea60cd033be7b..7e1f4aa0eb4801876d9bdbac6a4d7f1d09f81ba8 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -69,6 +69,37 @@ ENTRY main { RunTest(hlo_text, &operand, &scatter_indices, &updates); } +XLA_TEST_F(ScatterTest, TensorFlowScatterV1_WithFusedAdds) { + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + p0 = s32[3,3] parameter(0) + operand = s32[3,3] add(p0, p0) + p1 = s32[2] parameter(1) + indices = s32[2] add(p1, p1) + p2 = s32[2,3] parameter(2) + updates = s32[2,3] add(p2, p2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + 
inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 1}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) { const char* hlo_text = R"( HloModule TensorFlowScatterV2 @@ -98,6 +129,37 @@ ENTRY main { RunTest(hlo_text, &operand, &scatter_indices, &updates); } +XLA_TEST_F(ScatterTest, SimpleR4) { + const char* hlo_text = R"( +HloModule SimpleR4 + +add_f32 (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(f32[] lhs, f32[] rhs) +} + +ENTRY main { + operand = f32[1,2,2,1] parameter(0) + indices = s32[1,3] parameter(1) + updates = f32[1,2,2,1] parameter(2) + ROOT scatter = f32[1,2,2,1] scatter(operand, indices, updates), + to_apply=add_f32, + update_window_dims={1,2,3}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0, 2, 1}, + index_vector_dim=1 +} +)"; + + Literal operand = + LiteralUtil::CreateR4({{{{0.f}, {0.f}}, {{0.f}, {0.f}}}}); + Literal updates = + LiteralUtil::CreateR4({{{{0.12}, {0.28}}, {{0.018}, {0.42}}}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0, 0}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) { const string hlo_text = R"( HloModule TensorFlowScatter_Add diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 5155f0c652c7c6dbba60c421159494fa28072090..2f18036ff4c5b0bfa28723fb181c33fa6995eb80 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -272,9 +272,11 @@ std::vector FindConstrainedUses( constrained_uses.insert(constrained_uses.end(), converted_uses.begin(), converted_uses.end()); 
} else if (opcode == HloOpcode::kSort && - instruction->operand_count() == 2 && op_num == 0) { + instruction->operand_count() >= 2 && op_num == 0) { // Operand 0 of sort is the array of keys used for key/value - // (two-operand) kSort instructions. + // (two-operand) kSort instructions. Since sort stability is not + // guaranteed, constrain keys of key-value sort not to have duplicates, + // since otherwise the value order may legitimately differ. constrained_uses.push_back(instruction); } } diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 181e5cbe290b0df0cf605cc4ef4b8a4945b3d367..bc433eac8fcb02087d8e4eb10f638c85dc141b22 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -145,7 +146,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> ( ASSERT_EQ(args.size(), 2); const Literal& key_arg = args[0]; - tensorflow::gtl::FlatSet key_set; + absl::flat_hash_set key_set; for (const float& value : key_arg.data()) { EXPECT_TRUE(key_set.insert(tensorflow::bit_cast(value)).second); } @@ -168,7 +169,7 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> ( ASSERT_EQ(args.size(), 2); const Literal& key_arg = args[0]; - tensorflow::gtl::FlatSet key_set; + absl::flat_hash_set key_set; for (const int32& value : key_arg.data()) { EXPECT_TRUE(key_set.insert(tensorflow::bit_cast(value)).second); } diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 8b1b9e151992296b9d022ae1d9d974eadd2074a8..6d5f276e82087cedc356691b0ff08df24cec8d20 
100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -48,7 +48,7 @@ class WhileTest : public ClientLibraryTestBase {}; // while (result < 5) { // result = result + 1; // } -TEST_F(WhileTest, WhileWithScalarS32Result) { +XLA_TEST_F(WhileTest, WhileWithScalarS32Result) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. @@ -84,7 +84,7 @@ TEST_F(WhileTest, WhileWithScalarS32Result) { // while (result < 5) { // result = result + 1; // } -TEST_F(WhileTest, WhileWithScalarS64Result) { +XLA_TEST_F(WhileTest, WhileWithScalarS64Result) { auto result_shape = ShapeUtil::MakeShape(S64, {}); // Create a computation for the condition: repeat for 5 iterations. @@ -114,7 +114,7 @@ TEST_F(WhileTest, WhileWithScalarS64Result) { ComputeAndCompareR0(&builder, 5, {}); } -TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { +XLA_TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { auto result_shape = ShapeUtil::MakeShape(S32, {}); auto orig_shape = ShapeUtil::MakeShape(S32, {2}); @@ -147,7 +147,7 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { ComputeAndCompareR0(&builder, 5, {}); } -TEST_F(WhileTest, WhileWithPredicateResult) { +XLA_TEST_F(WhileTest, WhileWithPredicateResult) { auto result_shape = ShapeUtil::MakeShape(PRED, {}); // Create a computation for the condition: run until condition is true. @@ -184,7 +184,7 @@ TEST_F(WhileTest, WhileWithPredicateResult) { // while (result.sum() < 15.5f) { // result = result + vector(0); // } -TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) { +XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) { Shape result_shape = ShapeUtil::MakeShape(F32, {0}); // Create a computation for the reduction. 
@@ -238,7 +238,7 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) { // while (result.sum() < 15.5f) { // result = result + vector(8, 0.125f); // } -TEST_F(WhileTest, WhileWithVectorResult) { +XLA_TEST_F(WhileTest, WhileWithVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. @@ -298,7 +298,7 @@ TEST_F(WhileTest, WhileWithVectorResult) { // result = result + vector(8, 0.125f); // } // tuple = tuple { while } -TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { +XLA_TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. @@ -353,7 +353,7 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } -TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { +XLA_TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { std::vector shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -407,7 +407,7 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } -TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { +XLA_TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { std::vector shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -465,7 +465,7 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { // get<0>(result) = get<0>(result) + 1; // get<1>(result) = get<1>(result) + vector(10, 1.0f); // } -TEST_F(WhileTest, WhileWithTupleResult) { +XLA_TEST_F(WhileTest, WhileWithTupleResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -514,7 +514,7 @@ 
TEST_F(WhileTest, WhileWithTupleResult) { ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } -TEST_F(WhileTest, WhileWithPredicateTupleResult) { +XLA_TEST_F(WhileTest, WhileWithPredicateTupleResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(PRED, {})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -560,7 +560,7 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0)); } -TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { +XLA_TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -619,7 +619,7 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { // get<1>(w1) = get<1>(w1) + vector(10, 1.0f); // } // result = get<1>(w0) + get<1>(w1) -TEST_F(WhileTest, TwoWhileWithTupleResult) { +XLA_TEST_F(WhileTest, TwoWhileWithTupleResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -698,7 +698,7 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { } // Test while nodes that share the while body computation. 
-TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { +XLA_TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -763,7 +763,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } -TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) { +XLA_TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -901,7 +901,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Per backend the values generated can be different as the different backends // use different random number generators. // TODO(b/32240857): Extend test to verify outputs. -TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { +XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { auto v6s32 = ShapeUtil::MakeShape(S32, {6}); // Create a computation for the condition: repeat for count iterations. 
@@ -947,7 +947,7 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { } } -TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { +XLA_TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); XlaBuilder outer("outer"); @@ -979,7 +979,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { ErrorSpec(1e-6)); } -TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { +XLA_TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); XlaBuilder outer("outer"); @@ -1004,7 +1004,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { ErrorSpec(1e-6)); } -TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { +XLA_TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {}); XlaBuilder outer("outer"); @@ -1038,7 +1038,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { // result[0] = result[0] + 1; // result[1] = result[1] + 1; // } -TEST_F(WhileTest, WhileWithMixedTupleElements) { +XLA_TEST_F(WhileTest, WhileWithMixedTupleElements) { auto result_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); @@ -1146,7 +1146,7 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // while (f(result).get<0>()) { // result = result + 1; // } -TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { +XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. 
@@ -1186,7 +1186,7 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { ComputeAndCompareR0(&builder, 5, {}); } -TEST_F(WhileTest, WhileWithLoopInvariantOperation) { +XLA_TEST_F(WhileTest, WhileWithLoopInvariantOperation) { auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto while_shape = ShapeUtil::MakeTupleShape( @@ -1230,7 +1230,7 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { {param_value.get()}, ErrorSpec(4e-5)); } -TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { +XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { auto while_shape = ShapeUtil::MakeShape(S32, {}); XlaComputation condition; diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index db5a824de08edeb81b5deb047507dc6158833008..a6e70eb6ca25ffac24a8ebaf0420238e109e4fad 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" @@ -32,7 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -83,7 +83,7 @@ struct ParsedProfileOutputLine { Status ParseOneProfileOutputLine( const string& line, bool expect_hlo, - gtl::FlatMap* parsed_results, + absl::flat_hash_map* parsed_results, absl::Span opcodes_to_ignore = {}) { string separator = "[^:]*:: +"; string match_percentage = R"(\d+\.\d*% +\d+Σ)"; @@ -208,7 +208,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { std::vector profile_output_lines = absl::StrSplit(profile_output, '\n'); - gtl::FlatMap parsed_profile_lines; + absl::flat_hash_map parsed_profile_lines; TF_ASSERT_OK(ParseOneProfileOutputLine( profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines)); @@ -314,7 +314,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { ASSERT_NE(while_body_profile_end, profile_output_lines.end()); - gtl::FlatMap parsed_profile_lines; + absl::flat_hash_map parsed_profile_lines; for (auto while_body_profile_i = while_body_profile_start + 1; while_body_profile_i != while_body_profile_end; while_body_profile_i++) { diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 0c41f227b31ebe1f01073785ea2a666093aefdb3..f910e980535c073562473978662f73f4ee4bee79 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -83,7 +83,8 @@ std::unique_ptr CompileExecutable(const HloSnapshot& module, LocalClient* client) { XlaComputation computation(module.hlo().hlo_module()); std::vector argument_layouts; - for (const auto& param : computation.proto().program_shape().parameters()) { + for (const auto& param : + 
computation.proto().host_program_shape().parameters()) { argument_layouts.push_back(¶m); } return client diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 1d4f8d97f2ed8b263878b94b365b7fb5b949b1a2..dc62cf7a6b24e373374b458d2e4722e79500fb93 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -166,10 +166,22 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) { VLOG(1) << "Compiling XLA executable"; return Compile(ctx, computation_proto, program); })); - - Tensor output(DT_INT64, TensorShape({})); - output.scalar()() = uid; - ctx->set_output(0, output); + std::unique_ptr entry; + OP_REQUIRES_OK(ctx, cache->Lookup(uid, &entry)); + + Tensor handle_output(DT_INT64, TensorShape({})); + handle_output.scalar()() = uid; + ctx->set_output(0, handle_output); + + xla::LocalExecutable* executable = entry->get().get_executable(); + xla::ProgramShape program_shape = executable->executable() + ->module() + .config() + .entry_computation_layout() + .ComputeProgramShape(); + Tensor program_shape_output(DT_STRING, TensorShape({1})); + program_shape_output.vec()(0) = program_shape.SerializeAsString(); + ctx->set_output(1, program_shape_output); } XRTCompileOp::~XRTCompileOp() = default; diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 257b054f16a49f3e14e1d76746c9fe0ba7fa8658..3a1e03280a362f6048075be606865712efaffb77 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -64,14 +64,6 @@ uint32 GetXLARandomSeed() { return counter.fetch_add(2); } -// Looks up the input `key` in the compilation cache. 
-Status GetComputationCacheEntry( - XRTCompilationCache* cache, int64 key, - std::unique_ptr* entry) { - TF_RETURN_IF_ERROR(cache->Lookup(key, entry)); - return Status::OK(); -} - // Populates `inputs` with the input tensors to the computation. Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm, bool release_inputs, diff --git a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc index 5cfc8711f9f4b4d54016156dd53471cadb34b581..7b3b50c69559f6003a108fdf6a1325dbdbaa80a6 100644 --- a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc @@ -23,7 +23,12 @@ namespace tensorflow { REGISTER_OP("XRTCompile") .Input("computation: string") .Output("handle: int64") - .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Output("program_shape: string") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->UnknownShapeOfRank(1)); + return Status::OK(); + }) .Doc( R"( Reads a computation proto, compiles it, and places it in the global compilation diff --git a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc index fda4c31298ebc8c906418afdb8127492b1c5d3f0..40ec1b0ba9b336f5b6407c79c8d63e31219f9b84 100644 --- a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc @@ -21,7 +21,7 @@ limitations under the License. 
namespace tensorflow { REGISTER_OP("XRTExecute") - .Attr("Ninputs: int") + .Attr("Ninputs: int >= 0") .Input("computation_handle: int64") .Input("execution_config: string") .Input("input_handles: Ninputs * int64") diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD index b6dcfc4eb96316b5dad95a65b04d0ae69e4485f6..be44a3474acdeb9905c1d21b932fa0dd10b5a212 100644 --- a/tensorflow/compiler/xrt/tests/BUILD +++ b/tensorflow/compiler/xrt/tests/BUILD @@ -29,8 +29,11 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt:xrt_server", "//tensorflow/compiler/xrt/cc:xrt_ops", @@ -49,7 +52,10 @@ tf_cc_test( name = "raw_api_test_cpu", size = "medium", srcs = [], - args = ["--xla_test_device=XLA_CPU"], + args = [ + "--xla_test_device=XLA_CPU", + "--xla_platform=CPU", + ], deps = [ ":raw_api_test_lib", "//tensorflow/compiler/jit:xla_cpu_device", @@ -60,7 +66,10 @@ tf_cuda_cc_test( name = "raw_api_test_gpu", size = "medium", srcs = [], - args = ["--xla_test_device=XLA_GPU"], + args = [ + "--xla_test_device=XLA_GPU", + "--xla_platform=GPU", + ], tags = tf_cuda_tests_tags(), deps = [ ":raw_api_test_lib", diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 2952feb16a8a60aecf16be87c9b800d314c4af58..ad42148ce398fe5bb4494891bfa42500f904aa3f 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -22,10 +22,13 @@ limitations under the License. 
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" @@ -43,6 +46,7 @@ namespace tensorflow { namespace { string* xla_test_device_ptr; // initial value set in main() +string* xla_platform_ptr; // initial value set in main() string DeviceFromFlag() { string xla_test_device = *xla_test_device_ptr; @@ -108,6 +112,14 @@ bool CompareLiteralToLiteralProto(const xla::Literal& a, return equal; } +xla::XlaComputation OnePlusTwo() { + xla::XlaBuilder builder("OnePlusTwo"); + auto c0 = xla::ConstantR0(&builder, 1.0f); + auto c1 = xla::ConstantR0(&builder, 2.0f); + xla::Add(c0, c1); + return builder.Build().ValueOrDie(); +} + xla::XlaComputation AddAndScale() { xla::XlaBuilder builder("AddAndScale"); auto p0 = xla::Parameter(&builder, 0, @@ -120,6 +132,16 @@ xla::XlaComputation AddAndScale() { return builder.Build().ValueOrDie(); } +xla::XlaComputation AddS64() { + xla::XlaBuilder builder("AddS64"); + auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::S64, {}), + "P0"); + auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::S64, {}), + "P1"); + xla::Add(p0, p1); + return builder.Build().ValueOrDie(); +} + xla::XlaComputation AddAndTuple() { xla::XlaBuilder builder("AddAndTuple"); auto p0 = xla::Parameter(&builder, 0, @@ -137,6 +159,28 @@ void StoreComputationSnapshot(const xla::XlaComputation& computation, *dst = *snapshot; } 
+xla::ProgramShape XlaCompiledProgramShape( + const xla::XlaComputation& computation, + const xla::ProgramShape& input_program_shape) { + se::Platform* platform = + xla::PlatformUtil::GetPlatform(*xla_platform_ptr).ValueOrDie(); + xla::LocalClient* client = + xla::ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie(); + xla::ExecutableBuildOptions exec_options; + exec_options.set_result_layout(input_program_shape.result()); + std::vector parameters_shapes; + for (int64 i = 0; i < input_program_shape.parameters_size(); ++i) { + parameters_shapes.push_back(&input_program_shape.parameters(i)); + } + auto local_executable = + client->Compile(computation, parameters_shapes, exec_options) + .ValueOrDie(); + return local_executable->executable() + ->module() + .entry_computation() + ->ComputeProgramShape(); +} + TEST(RawApiTest, ReadAndWriteState) { xrt::XLAAllocation alloc; alloc.set_device_ordinal(0); @@ -330,20 +374,120 @@ TEST(RawApiTest, CompileAndExecute) { auto p1_value = ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); auto p1_handle = ops::XRTAllocate(root, p1_value); - auto result = ops::XRTExecute(root, c_handle, e_config, + auto result = ops::XRTExecute(root, c_handle.handle, e_config, {Output(p0_handle), Output(p1_handle)}); auto read_back = ops::XRTReadLiteralAndRelease(root, result); TF_ASSERT_OK(root.status()); ClientSession session(root); std::vector outputs; - TF_EXPECT_OK(session.Run({read_back}, &outputs)); + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); xla::LiteralProto response; EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); + + xla::ProgramShape program_shape; + EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); + EXPECT_EQ(program_shape.parameters_size(), 2); +} + +TEST(RawApiTest, CompileWithXlaReturnShapes) { + xla::XlaBuilder 
builder("XrtXlaShapes"); + auto input_shape = xla::ShapeUtil::MakeShape(xla::BF16, {32, 3, 128, 128}); + auto kernel_shape = xla::ShapeUtil::MakeShape(xla::BF16, {3, 3, 5, 5}); + // Clear layouts to signal XLA we are ready to get whatever are coming out of + // the compilation process. + xla::LayoutUtil::ClearLayout(&input_shape); + xla::LayoutUtil::ClearLayout(&kernel_shape); + auto param_shape = + xla::ShapeUtil::MakeTupleShape({input_shape, kernel_shape}); + auto param = xla::Parameter(&builder, 0, param_shape, "param"); + auto input = xla::GetTupleElement(param, 0); + auto kernel = xla::GetTupleElement(param, 1); + xla::Conv(input, kernel, {1, 1}, xla::Padding::kSame); + TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation xla_computation, builder.Build()); + + auto result_shape = xla_computation.GetProgramShape().ValueOrDie().result(); + // Clear the result shape layout to tell XLA we are accepting whatever are + // coming out of the compilation process. + xla::LayoutUtil::ClearLayout(&result_shape); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = param_shape; + *shapes->mutable_result() = result_shape; + StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot()); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto release = ops::XRTReleaseCompilationHandle(root, c_handle.handle); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), + {c_handle.program_shape}, {release}, &outputs)); + + xla::ProgramShape program_shape; + EXPECT_TRUE(program_shape.ParseFromString(outputs[0].vec()(0))); + EXPECT_EQ(program_shape.parameters_size(), 1); + + VLOG(2) << "Param: " + << 
xla::ShapeUtil::HumanStringWithLayout(program_shape.parameters(0)); + VLOG(2) << "Result: " + << xla::ShapeUtil::HumanStringWithLayout(program_shape.result()); + + xla::ProgramShape xla_program_shape = + XlaCompiledProgramShape(xla_computation, *shapes); + EXPECT_TRUE(xla::LayoutUtil::Equal( + xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(), + xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0}) + .layout())); + EXPECT_TRUE(xla::LayoutUtil::Equal( + xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(), + xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1}) + .layout())); + EXPECT_TRUE(xla::LayoutUtil::Equal(program_shape.result().layout(), + xla_program_shape.result().layout())); +} + +TEST(RawApiTest, CompileAndExecuteZeroArg) { + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + StoreComputationSnapshot(OnePlusTwo(), c.mutable_hlo_snapshot()); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + std::initializer_list({})); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR0(3.0f); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } 
TEST(RawApiTest, CompileAndExecuteReturnTuple) { @@ -379,7 +523,7 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { auto p1_value = ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); auto p1_handle = ops::XRTAllocate(root, p1_value); - auto result = ops::XRTExecute(root, c_handle, e_config, + auto result = ops::XRTExecute(root, c_handle.handle, e_config, {Output(p0_handle), Output(p1_handle)}); auto read_back = ops::XRTReadLiteralAndRelease(root, result); TF_ASSERT_OK(root.status()); @@ -396,15 +540,93 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } +TEST(RawApiTest, LeakCompilationReference) { + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); + *shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::F32, {2})}); + StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot()); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({c_handle.handle}, &outputs)); +} + +TEST(RawApiTest, CompileAndExecuteWithS64Argument) { + xrt::XLAAllocation p0; + p0.set_device_ordinal(0); + *p0.mutable_value() = xla::LiteralUtil::CreateR0(11031965).ToProto(); + xrt::XLAAllocation p1; + p1.set_device_ordinal(0); + *p1.mutable_value() = xla::LiteralUtil::CreateR0(4091934).ToProto(); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}); + 
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}); + *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}); + StoreComputationSnapshot(AddS64(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = + ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = + ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR0(15123899); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); + + xla::ProgramShape program_shape; + EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); + EXPECT_EQ(program_shape.parameters_size(), 2); + EXPECT_TRUE( + xla::ShapeUtil::HasPrimitiveType(program_shape.result(), xla::S64)); +} + } // namespace } // namespace tensorflow int main(int argc, char** argv) { tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU"); + tensorflow::xla_platform_ptr = new tensorflow::string("CPU"); std::vector flag_list = { tensorflow::Flag("xla_test_device", 
tensorflow::xla_test_device_ptr, "Tensorflow device type to use for test, e.g., XLA_CPU"), + tensorflow::Flag("xla_platform", tensorflow::xla_platform_ptr, + "The XLA platform to select for the device"), }; tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.cc b/tensorflow/compiler/xrt/xrt_compilation_cache.cc index 4844c7fb7106862dd42b3b3d07245350c9d2383c..d1405eae468492748ae88d842334a922dce272c6 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.cc +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.cc @@ -18,9 +18,19 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace { + +int64 get_uid() { + uint64 unsigned_rand = random::New64() & INT64_MAX; + return static_cast(unsigned_rand); +} + +} // namespace + const char* kXRTCompilationCacheResourceName = "xrt_compilation_cache"; XRTCompilationCache::EntryRefImpl::EntryRefImpl(XRTCompilationCache* parent, @@ -46,12 +56,17 @@ XRTCompilationCache::XRTCompilationCache(int max_number_of_entries) XRTCompilationCache::~XRTCompilationCache() { VLOG(1) << "XRTCompilationCache::~XRTCompilationCache()"; + // A buggy client may be holding onto a reference, or a client might have + // crashed while holding onto a reference. In either case, discard all + // outstanding client references to avoid leaking storage. + for (const auto& entry : entries_by_uid_) { + while (!entry.second->RefCountIsOne()) { + entry.second->Unref(); + } + } while (!entries_by_last_use_.empty()) { MarkOldestEntryForEviction(); } - // By the time the cache is deleted all reference holders should have already - // been deleted, since they were holding references to the cache. 
So all - // entries should be gone at this point. CHECK_EQ(cache_.size(), 0); CHECK_EQ(entries_by_uid_.size(), 0); CHECK_EQ(cache_entries_, 0); @@ -148,7 +163,7 @@ XRTCompilationCache::CompiledSubgraph* XRTCompilationCache::InitializeEntry( CompiledSubgraph* entry = new CompiledSubgraph(); entry->parent = this; entry->key = key; - entry->uid = next_uid_++; + entry->uid = get_uid(); // Add the entry to the cache. Once the computation has been compiled, // UpdateEntryAfterCompilation will be called to potentially mark old entries // that don't fit any more for eviction. diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.h b/tensorflow/compiler/xrt/xrt_compilation_cache.h index c505299a454506e2136e36fb26833c28ed0d47bc..c43d0fc47873abdc82ee937c155bebc346a05f17 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.h +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.h @@ -211,8 +211,6 @@ class XRTCompilationCache : public ResourceBase { const int max_cache_entries_; mutable absl::Mutex mu_; - // The uid to assign to the next new entry created. - int64 next_uid_ GUARDED_BY(mu_) = 0; // The total number of entries that are stored and not marked for eviction. int cache_entries_ GUARDED_BY(mu_) = 0; // The total number of entries that are marked for eviction. diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index d05a1e7dcbff440e0daf03bd25535c26d82b6a0b..3a99820d7aa9e9546cc95385fd98c05f28988e9e 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -42,12 +43,9 @@ namespace { const char* kTupleContainer = "tuples"; -// Counter used to assign unique handles. -mutex _uid_mutex(tensorflow::LINKER_INITIALIZED); -int64 _uid GUARDED_BY(_uid_mutex) = 0; int64 get_uid() { - mutex_lock l(_uid_mutex); - return _uid++; + uint64 unsigned_rand = random::New64() & INT64_MAX; + return static_cast(unsigned_rand); } Status AllocateScopedShapedBuffer( @@ -67,6 +65,9 @@ Status AllocateScopedShapedBuffer( // requests the host-shape sub-buffer at index i, that will correspond to the // right device-shape sub-buffer at the same index. xla::Shape on_device_shape = transfer_manager->HostShapeToDeviceShape(shape); + VLOG(3) << "Allocating literal buffer: host_shape=" + << xla::ShapeUtil::HumanStringWithLayout(shape) << " device_shape=" + << xla::ShapeUtil::HumanStringWithLayout(on_device_shape); // The ScopedShapedBuffer frees the buffers that have so far been allocated if // it goes out of scope. 
That's useful if we return early as the result of an diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 98dff965a94cdd2138ddbe3a160e20b0b0cb3197..78ad19a4ab112be08569c857c8ed4e16ceed6d80 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -29,6 +29,7 @@ py_library( "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/coder:coder_py", "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/contrib/compiler:xla", "//tensorflow/contrib/autograph", "//tensorflow/contrib/constrained_optimization", "//tensorflow/contrib/copy_graph:copy_graph_py", @@ -112,17 +113,52 @@ py_library( "//tensorflow/python:util", "//tensorflow/python/estimator:estimator_py", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], "//tensorflow:linux_s390x": [], "//tensorflow:windows": [], + "//tensorflow:no_kafka_support": [], "//conditions:default": [ - "//tensorflow/contrib/bigtable", - "//tensorflow/contrib/cloud:cloud_py", - "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols "//tensorflow/contrib/kafka", + ], + }) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//tensorflow:no_aws_support": [], + "//conditions:default": [ "//tensorflow/contrib/kinesis", + ], + }) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//conditions:default": [ + "//tensorflow/contrib/fused_conv:fused_conv_py", "//tensorflow/contrib/tensorrt:init_py", "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", ], + }) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//tensorflow:no_gcp_support": [], + "//conditions:default": [ + "//tensorflow/contrib/bigtable", + 
"//tensorflow/contrib/cloud:cloud_py", + ], + }) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//tensorflow:no_ignite_support": [], + "//conditions:default": [ + "//tensorflow/contrib/ignite", + ], }), ) @@ -146,14 +182,26 @@ cc_library( ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([ "//tensorflow/contrib/nccl:nccl_kernels", ]) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], "//tensorflow:linux_s390x": [], "//tensorflow:windows": [], + "//tensorflow:no_kafka_support": [], "//conditions:default": [ "//tensorflow/contrib/kafka:dataset_kernels", + ], + }) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//tensorflow:no_aws_support": [], + "//conditions:default": [ "//tensorflow/contrib/kinesis:dataset_kernels", - "//tensorflow/contrib/tensorrt:trt_engine_op_kernel", ], - }), + }) + if_not_windows([ + "//tensorflow/contrib/tensorrt:trt_engine_op_kernel", + ]), ) cc_library( @@ -177,12 +225,33 @@ cc_library( "//tensorflow/contrib/text:all_ops", "//tensorflow/contrib/tpu:all_ops", ] + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], "//tensorflow:linux_s390x": [], "//tensorflow:windows": [], + "//tensorflow:no_kafka_support": [], "//conditions:default": [ "//tensorflow/contrib/kafka:dataset_ops_op_lib", + ], + }) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//tensorflow:no_aws_support": [], + "//conditions:default": [ "//tensorflow/contrib/kinesis:dataset_ops_op_lib", - "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib", + ], + }) + if_not_windows([ + "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib", + ]) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + 
"//tensorflow:no_ignite_support": [], + "//conditions:default": [ + "//tensorflow/contrib/ignite:dataset_ops_op_lib", ], }), ) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index e71b0e0ae33f9c2dd48643e557447372bc67b3e3..f52a1a7babceeae93cdd2e5a93dad413a1d30191 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -21,14 +21,6 @@ from __future__ import print_function import os -from tensorflow.python.tools import component_api_helper -component_api_helper.package_hook( - parent_package_str=( - "tensorflow.contrib"), - child_package_str=( - "tensorflow_estimator.contrib.estimator")) -del component_api_helper - # Add projects here, they will show up under tf.contrib. from tensorflow.contrib import autograph from tensorflow.contrib import batching diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index b27a19b16c08cb588b45949105a6399623e766e1..648f3ebb05646a66144bcb118347cbc391909409 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -7,64 +7,6 @@ package( licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "batch_scheduler_hdrs", - hdrs = ["batch_scheduler.h"], - deps = [ - "//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs", - ], -) - -cc_library( - name = "batch_scheduler", - hdrs = ["batch_scheduler.h"], - deps = [ - "//tensorflow/core/kernels/batching_util:batch_scheduler", - ], -) - -cc_library( - name = "shared_batch_scheduler_hdrs", - hdrs = ["shared_batch_scheduler.h"], - deps = [ - "//tensorflow/core/kernels/batching_util:shared_batch_scheduler_hdrs", - ], -) - -cc_library( - name = "shared_batch_scheduler", - hdrs = ["shared_batch_scheduler.h"], - deps = [ - "//tensorflow/core/kernels/batching_util:shared_batch_scheduler", - ], - alwayslink = 1, -) - -cc_library( - name = "adaptive_shared_batch_scheduler", - hdrs = ["adaptive_shared_batch_scheduler.h"], - deps 
= [ - "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler", - ], -) - -cc_library( - name = "serial_device_batch_scheduler", - hdrs = ["serial_device_batch_scheduler.h"], - deps = [ - "//tensorflow/core/kernels/batching_util:serial_device_batch_scheduler", - ], -) - -cc_library( - name = "basic_batch_scheduler", - hdrs = ["basic_batch_scheduler.h"], - deps = [ - "//tensorflow/core/kernels/batching_util:basic_batch_scheduler", - ], -) - load( "//tensorflow:tensorflow.bzl", "py_test", diff --git a/tensorflow/contrib/batching/serial_device_batch_scheduler.h b/tensorflow/contrib/batching/serial_device_batch_scheduler.h deleted file mode 100644 index bf6b7083612018eecf0d1784e60cbbf0c5796fef..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/batching/serial_device_batch_scheduler.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ -#define TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ - -#include "tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h" - -#endif // TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/test_util/BUILD b/tensorflow/contrib/batching/test_util/BUILD deleted file mode 100644 index 7cb2d8079bd18660f72eab92654629434ce4d6a5..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/batching/test_util/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -# Description: Utilities to aid testing. - -package( - default_visibility = ["//tensorflow:internal"], -) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -cc_library( - name = "fake_clock_env", - testonly = 1, - hdrs = ["fake_clock_env.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core/kernels/batching_util:fake_clock_env", - ], -) diff --git a/tensorflow/contrib/batching/test_util/fake_clock_env.h b/tensorflow/contrib/batching/test_util/fake_clock_env.h deleted file mode 100644 index 40a39a5569854350c72a47102f3dac07b362ce8e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/batching/test_util/fake_clock_env.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ -#define TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ - -#include "tensorflow/core/kernels/batching_util/fake_clock_env.h" - -#endif // TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ diff --git a/tensorflow/contrib/batching/util/BUILD b/tensorflow/contrib/batching/util/BUILD deleted file mode 100644 index 8f81b6702f2807d7da7e72190ce2d86b28e52113..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/batching/util/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -# Description: Utilities. - -package( - default_visibility = ["//tensorflow:internal"], -) - -licenses(["notice"]) # Apache 2.0 - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "periodic_function_dynamic", - hdrs = ["periodic_function.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core/kernels/batching_util:periodic_function_dynamic", - "//third_party/eigen3", - ], -) - -cc_library( - name = "periodic_function", - visibility = ["//visibility:public"], - deps = [ - ":periodic_function_dynamic", - "//tensorflow/core/kernels/batching_util:periodic_function", - ], -) diff --git a/tensorflow/contrib/batching/util/periodic_function.h b/tensorflow/contrib/batching/util/periodic_function.h deleted file mode 100644 index aa2ed0a385125fa090a7a56b6339a87eb2d57b1f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/batching/util/periodic_function.h +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ -#define TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ - -#include "tensorflow/core/kernels/batching_util/periodic_function.h" - -#endif // TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index 41a8c920fc4e81af90f4c94a149d8c404c58b747..493046b39907971e92f91ecc60d375ea273ff1d2 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -14,6 +14,8 @@ # ============================================================================== """Ops for representing Bayesian computation. +Use [tfp](/probability/api_docs/python/tfp) instead. + ## This package provides classes for Bayesian computation with TensorFlow. """ from __future__ import absolute_import diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py index 68fa415eeaf1d1ae7c2ecf1be1c300eddbfa4e69..28a829d87ddecc4a147c588b5b0536b44db8393f 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Monte Carlo integration and helpers.""" +"""Monte Carlo integration and helpers. 
+ +Use [tfp.monte_carlo](/probability/api_docs/python/tfp/monte_carlo) instead. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md index f33eaf7e3df356e10939f591ef75cb4f17978144..2c44abed5e1955cc666273e97e6b2378766f13d2 100644 --- a/tensorflow/contrib/bigtable/README.md +++ b/tensorflow/contrib/bigtable/README.md @@ -203,7 +203,7 @@ def interleave_fn(index): start = tf.string_join(['training_data_', start_idx_str]) end = tf.string_join(['training_data_', end_idx_str]) return table.scan_range(start_idx, end_idx, columns=columns) -ds = ds.apply(tf.contrib.data.parallel_interleave( +ds = ds.apply(tf.data.experimental.parallel_interleave( interleave_fn, cycle_length=NUM_PARALLEL_READS, prefetch_input_elements=1)) ``` @@ -249,7 +249,7 @@ def make_row_key_dataset(): - ... - fake-data-23498103 """ - counter_dataset = tf.contrib.data.Counter() + counter_dataset = tf.data.experimental.Counter() width = 8 row_key_prefix = 'fake-data-' ds = counter_dataset.map(lambda index: tf.as_string(index, diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index cf56822ff401ac81c4bb7bd65418c81fe7b398d8..7c87b0daeb09950cc44c51f49c16534d413f0376 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -31,8 +31,8 @@ from six import iteritems from six import string_types from tensorflow.contrib.bigtable.ops import gen_bigtable_ops -from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.util import loader +from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes @@ -228,7 +228,7 @@ class BigtableTable(object): """Retrieves a sampling of row keys from 
the Bigtable table. This dataset is most often used in conjunction with - `tf.contrib.data.parallel_interleave` to construct a set of ranges for + `tf.data.experimental.parallel_interleave` to construct a set of ranges for scanning in parallel. Returns: diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index f03eab510c2f9010fc92eb1934ac77dc0626a44b..f7f15a302a00ee4187d57fc4d40727b84e6c587c 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -98,7 +98,6 @@ py_library( "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", - "//tensorflow/contrib/stateless", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", @@ -108,6 +107,7 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:stateless_random_ops", "//tensorflow/python:summary", "//tensorflow/python:tensor_shape", "//tensorflow/python:training", diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 48f12a64f94c7bd0531488ef537b199558e17e3e..a3df272e6924792128fc38fd153b9527b58b486e 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -196,6 +196,10 @@ def convert_to_universal_format(dtec, sorted_feature_names, matching_id = categorical_test.value.add() matching_id.int64_value = split.feature_id node.custom_left_child_test.Pack(categorical_test) + elif (node_type == "oblivious_dense_float_binary_split" or + node_type == "oblivious_categorical_id_binary_split"): + raise ValueError("Universal tree format doesn't support oblivious " + "trees") else: raise ValueError("Unexpected node 
type %s" % node_type) node.left_child_id.value = split.left_id @@ -229,6 +233,13 @@ def _get_feature_importances(dtec, feature_names, num_dense_floats, split = tree_node.categorical_id_binary_split split_column = feature_names[split.feature_column + num_dense_floats + num_sparse_float] + elif node_type == "oblivious_dense_float_binary_split": + split = tree_node.oblivious_dense_float_binary_split + split_column = feature_names[split.feature_column] + elif node_type == "oblivious_categorical_id_binary_split": + split = tree_node.oblivious_categorical_id_binary_split + split_column = feature_names[split.feature_column + num_dense_floats + + num_sparse_float] elif node_type == "categorical_id_set_membership_binary_split": split = tree_node.categorical_id_set_membership_binary_split split_column = feature_names[split.feature_column + num_dense_floats + diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index 194a5c8754cb0ab2db299e3fb5c998c0f27f8435..ca73e4af2fbd0a383d02fa7111f59161701661df 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -28,7 +28,6 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.boosted_trees.estimator_batch import model from tensorflow.contrib.boosted_trees.estimator_batch import distillation_loss -from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch @@ -170,6 +169,7 @@ def _dnn_tree_combined_model_fn( if (output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC and not use_core_versions): raise ValueError("You must use core 
versions with Estimator Spec") + global_step = training_util.get_global_step() with variable_scope.variable_scope( dnn_parent_scope, @@ -191,46 +191,58 @@ def _dnn_tree_combined_model_fn( feature_columns=dnn_feature_columns, weight_collections=[dnn_parent_scope], scope=input_layer_scope) - previous_layer = input_layer - for layer_id, num_hidden_units in enumerate(dnn_hidden_units): + def dnn_logits_fn(): + """Builds the logits from the input layer.""" + previous_layer = input_layer + for layer_id, num_hidden_units in enumerate(dnn_hidden_units): + with variable_scope.variable_scope( + "hiddenlayer_%d" % layer_id, + values=(previous_layer,)) as hidden_layer_scope: + net = layers.fully_connected( + previous_layer, + num_hidden_units, + activation_fn=dnn_activation_fn, + variables_collections=[dnn_parent_scope], + scope=hidden_layer_scope) + if dnn_dropout is not None and mode == model_fn.ModeKeys.TRAIN: + net = layers.dropout(net, keep_prob=(1.0 - dnn_dropout)) + _add_hidden_layer_summary(net, hidden_layer_scope.name) + previous_layer = net with variable_scope.variable_scope( - "hiddenlayer_%d" % layer_id, - values=(previous_layer,)) as hidden_layer_scope: - net = layers.fully_connected( + "logits", values=(previous_layer,)) as logits_scope: + dnn_logits = layers.fully_connected( previous_layer, - num_hidden_units, - activation_fn=dnn_activation_fn, + head.logits_dimension, + activation_fn=None, variables_collections=[dnn_parent_scope], - scope=hidden_layer_scope) - if dnn_dropout is not None and mode == model_fn.ModeKeys.TRAIN: - net = layers.dropout(net, keep_prob=(1.0 - dnn_dropout)) - _add_hidden_layer_summary(net, hidden_layer_scope.name) - previous_layer = net - with variable_scope.variable_scope( - "logits", values=(previous_layer,)) as logits_scope: - dnn_logits = layers.fully_connected( - previous_layer, - head.logits_dimension, - activation_fn=None, - variables_collections=[dnn_parent_scope], - scope=logits_scope) - _add_hidden_layer_summary(dnn_logits, 
logits_scope.name) - - def _dnn_train_op_fn(loss): - """Returns the op to optimize the loss.""" - return optimizers.optimize_loss( - loss=loss, - global_step=training_util.get_global_step(), - learning_rate=_DNN_LEARNING_RATE, - optimizer=_get_optimizer(dnn_optimizer), - name=dnn_parent_scope, - variables=ops.get_collection( - ops.GraphKeys.TRAINABLE_VARIABLES, scope=dnn_parent_scope), - # Empty summaries to prevent optimizers from logging training_loss. - summaries=[]) + scope=logits_scope) + _add_hidden_layer_summary(dnn_logits, logits_scope.name) + return dnn_logits + if predict_with_tree_only and mode == model_fn.ModeKeys.INFER: + dnn_logits = array_ops.constant(0.0) + dnn_train_op_fn = control_flow_ops.no_op + elif predict_with_tree_only and mode == model_fn.ModeKeys.EVAL: + dnn_logits = control_flow_ops.cond( + global_step > dnn_steps_to_train, + lambda: array_ops.constant(0.0), + dnn_logits_fn) + dnn_train_op_fn = control_flow_ops.no_op + else: + dnn_logits = dnn_logits_fn() + def dnn_train_op_fn(loss): + """Returns the op to optimize the loss.""" + return optimizers.optimize_loss( + loss=loss, + global_step=training_util.get_global_step(), + learning_rate=_DNN_LEARNING_RATE, + optimizer=_get_optimizer(dnn_optimizer), + name=dnn_parent_scope, + variables=ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES, scope=dnn_parent_scope), + # Empty summaries to prevent optimizers from logging training_loss. + summaries=[]) # Build Tree Logits. 
- global_step = training_util.get_global_step() with ops.device(global_step.device): ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, @@ -261,8 +273,13 @@ def _dnn_tree_combined_model_fn( """Returns the op to optimize the loss.""" if dnn_to_tree_distillation_param: loss_weight, loss_fn = dnn_to_tree_distillation_param - weight_tensor = head_lib._weight_tensor( # pylint: disable=protected-access - features, head.weight_column_name) + # pylint: disable=protected-access + if use_core_versions: + weight_tensor = head_lib._weight_tensor(features, head._weight_column) + else: + weight_tensor = head_lib._weight_tensor( + features, head.weight_column_name) + # pylint: enable=protected-access dnn_logits_fixed = array_ops.stop_gradient(dnn_logits) if loss_fn is None: @@ -305,52 +322,26 @@ def _dnn_tree_combined_model_fn( finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() if output_type == model.ModelBuilderOutputType.MODEL_FN_OPS: - if use_core_versions: - model_fn_ops = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_no_train_op_fn, - logits=tree_train_logits) - dnn_train_op = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_dnn_train_op_fn, - logits=dnn_logits) - dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops( - dnn_train_op).train_op - - tree_train_op = head.create_estimator_spec( - features=tree_features, - mode=mode, - labels=labels, - train_op_fn=_tree_train_op_fn, - logits=tree_train_logits) - tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops( - tree_train_op).train_op - - model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops( - model_fn_ops) - else: - model_fn_ops = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_no_train_op_fn, - logits=tree_train_logits) - dnn_train_op = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - 
train_op_fn=_dnn_train_op_fn, - logits=dnn_logits).train_op - tree_train_op = head.create_model_fn_ops( - features=tree_features, - mode=mode, - labels=labels, - train_op_fn=_tree_train_op_fn, - logits=tree_train_logits).train_op + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + if mode != model_fn.ModeKeys.TRAIN: + return model_fn_ops + dnn_train_op = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=dnn_train_op_fn, + logits=dnn_logits).train_op + tree_train_op = head.create_model_fn_ops( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits).train_op # Add the hooks model_fn_ops.training_hooks.extend([ @@ -369,11 +360,13 @@ def _dnn_tree_combined_model_fn( labels=labels, train_op_fn=_no_train_op_fn, logits=tree_train_logits) + if mode != model_fn.ModeKeys.TRAIN: + return fusion_spec dnn_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, - train_op_fn=_dnn_train_op_fn, + train_op_fn=dnn_train_op_fn, logits=dnn_logits) tree_spec = head.create_estimator_spec( features=tree_features, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py index 839eedd3a87ccaa1faecd1966fe5907d682cac02..dea19b7c62649679f944809b44c51ba0cd361904 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -18,13 +18,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import tempfile from tensorflow.contrib.boosted_trees.estimator_batch import dnn_tree_combined_estimator as estimator from tensorflow.contrib.boosted_trees.proto 
import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils from tensorflow.contrib.learn.python.learn.estimators import run_config +from tensorflow.python.estimator import exporter from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.export import export +from tensorflow.python.ops import parsing_ops from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -34,6 +38,7 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest from tensorflow.python.training import checkpoint_utils + def _train_input_fn(): features = { "x": constant_op.constant([[2.], [1.], [1.]]) @@ -103,35 +108,6 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): classifier.fit(input_fn=_train_input_fn, steps=15) classifier.evaluate(input_fn=_eval_input_fn, steps=1) - def testFitAndEvaluateDontThrowExceptionWithCore(self): - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 1 - model_dir = tempfile.mkdtemp() - config = run_config.RunConfig() - - # Use core head - head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( - loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) - - classifier = estimator.DNNBoostedTreeCombinedEstimator( - head=head_fn, - dnn_hidden_units=[1], - # Use core feature columns - dnn_feature_columns=[core_feature_column.numeric_column("x")], - tree_learner_config=learner_config, - num_trees=1, - tree_examples_per_layer=3, - model_dir=model_dir, - config=config, - dnn_steps_to_train=10, - dnn_input_layer_to_tree=True, - tree_feature_columns=[], - use_core_versions=True) - - classifier.fit(input_fn=_train_input_fn, steps=15) - 
classifier.evaluate(input_fn=_eval_input_fn, steps=1) - def testFitAndEvaluateWithDistillation(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 @@ -223,6 +199,51 @@ class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): self.assertLess(0.5, res["auc"]) est.predict(input_fn=_eval_input_fn) + def testTrainEvaluateWithDnnForInputAndTreeForPredict(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + est = estimator.CoreDNNBoostedTreeCombinedEstimator( + head=head_fn, + dnn_hidden_units=[1], + dnn_feature_columns=[core_feature_column.numeric_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=True, + predict_with_tree_only=True, + dnn_to_tree_distillation_param=(0.5, None), + tree_feature_columns=[]) + + # Train for a few steps. 
+ est.train(input_fn=_train_input_fn, steps=1000) + res = est.evaluate(input_fn=_eval_input_fn, steps=1) + self.assertLess(0.5, res["auc"]) + est.predict(input_fn=_eval_input_fn) + serving_input_fn = ( + export.build_parsing_serving_input_receiver_fn( + feature_spec={"x": parsing_ops.FixedLenFeature( + [1], dtype=dtypes.float32)})) + base_exporter = exporter.FinalExporter( + name="Servo", + serving_input_receiver_fn=serving_input_fn, + assets_extra=None) + export_path = os.path.join(model_dir, "export") + base_exporter.export( + est, + export_path=export_path, + checkpoint_path=None, + eval_result={}, + is_the_final_export=True) if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 8531e97f90236b8e5eb64bc0f4c9bb3b674f35cd..bd5d5bb695684c7dcb5cc5c0038386074edd4dc4 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -22,7 +22,6 @@ import collections import copy from tensorflow.contrib import learn -from tensorflow.contrib import stateless from tensorflow.contrib.boosted_trees.lib.learner.batch import categorical_split_handler from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler from tensorflow.contrib.boosted_trees.proto import learner_pb2 @@ -44,6 +43,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops +from tensorflow.python.ops import stateless_random_ops as stateless from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py 
b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 1056894f18f1ec19a598dfbd1161d7f9bea7e94f..f4a8e16c99f464b813a98e981579bd0ff53bd464 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -60,6 +60,7 @@ class TPUClusterResolver(ClusterResolver): if (self._tpu == compat.as_bytes('') or self._tpu == compat.as_bytes('local') or self._tpu.startswith(compat.as_bytes('/bns')) or + self._tpu.startswith(compat.as_bytes('localhost:')) or self._tpu.startswith(compat.as_bytes('grpc://'))): return False return True diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index c6d6f04168b0c95662123f5feceb5ebb0474ffe9..fbdca497fcc3126d2086d289ebdb113370072d22 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -1,6 +1,16 @@ # Minimum CMake required cmake_minimum_required(VERSION 3.5) +if(WIN32) + if(${CMAKE_VERSION} VERSION_LESS "3.8") + message(WARNING "Your current cmake version is ${CMAKE_VERSION} which does not support setting the toolset architecture to x64. This may cause \"compiler out of heap space\" errors when building. Consider upgrading your cmake to > 3.8 and using the flag -Thost=x64 when running cmake.") + else() + if(NOT CMAKE_VS_PLATFORM_TOOLSET_HOST_ARCHITECTURE OR NOT "${CMAKE_VS_PLATFORM_TOOLSET_HOST_ARCHITECTURE}" STREQUAL "x64") + message(WARNING "Your current cmake generator is set to use 32 bit toolset architecture. This may cause \"compiler out of heap space\" errors when building. 
Consider using the flag -Thost=x64 when running cmake.") + endif() + endif() +endif() + # Project project(tensorflow C CXX) @@ -30,7 +40,6 @@ endif() option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON) option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF) -option(tensorflow_ENABLE_JEMALLOC_SUPPORT "Enable jemalloc support" OFF) option(tensorflow_BUILD_CC_EXAMPLE "Build the C++ tutorial example" ON) option(tensorflow_BUILD_PYTHON_BINDINGS "Build the Python bindings" ON) option(tensorflow_BUILD_ALL_KERNELS "Build all OpKernels" ON) @@ -90,10 +99,12 @@ if (NOT WIN32) # Options for linking other libraries option(systemlib_ZLIB "Use the system installed library as shared objects instead of downloading ZLIB and statically linking to it: ZLIB" OFF) + option(systemlib_ABSEIL_CPP "Use the system installed library as shared objects instead of downloading ABSEIL_CPP and statically linking to it: ABSEIL_CPP" OFF) option(systemlib_ALL "Turn on every possible systemlib_* options" OFF) if (systemlib_ALL) set (systemlib_ZLIB ON) + set (systemlib_ABSEIL_CPP ON) endif (systemlib_ALL) endif() @@ -115,7 +126,7 @@ function(SHOW_VARIABLES) endfunction() # External dependencies -set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/external) +set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/external ${PROJECT_SOURCE_DIR}/modules) # Location where external projects will be downloaded set (DOWNLOAD_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/downloads" @@ -218,10 +229,6 @@ if (tensorflow_WIN_CPU_SIMD_OPTIONS) endif() endif() -if (tensorflow_ENABLE_JEMALLOC_SUPPORT) - add_definitions(-DTENSORFLOW_USE_JEMALLOC -DJEMALLOC_EXPORT=) -endif() - # External dependencies include(zlib) include(gif) @@ -240,6 +247,7 @@ include(re2) include(cub) include(sqlite) include(double_conversion) +include(abseil_cpp) if (tensorflow_BUILD_CC_TESTS) include(googletest) endif() @@ -248,6 +256,7 @@ add_definitions(${ADD_CFLAGS}) link_directories(${ADD_LINK_DIRECTORY}) set(tensorflow_EXTERNAL_LIBRARIES + 
${tensorflow_EXTERNAL_LIBRARIES} ${gif_STATIC_LIBRARIES} ${png_STATIC_LIBRARIES} ${jpeg_STATIC_LIBRARIES} @@ -329,12 +338,6 @@ if(tensorflow_ENABLE_GRPC_SUPPORT) list(APPEND tensorflow_EXTERNAL_DEPENDENCIES boringssl) endif() endif() -if(tensorflow_ENABLE_JEMALLOC_SUPPORT) - include(jemalloc) - list(APPEND tensorflow_EXTERNAL_LIBRARIES ${jemalloc_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES jemalloc) - include_directories(${jemalloc_INCLUDE_DIRS}) -endif() if(tensorflow_ENABLE_SNAPPY_SUPPORT) include(snappy) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${snappy_STATIC_LIBRARIES}) @@ -363,9 +366,7 @@ if (tensorflow_ENABLE_MKL_SUPPORT) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES}) list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination) include_directories(${mkldnn_INCLUDE_DIRS}) - else (tensorflow_ENABLE_MKLDNN_SUPPORT) - add_definitions(-DINTEL_MKL_ML_ONLY) - endif() + endif(tensorflow_ENABLE_MKLDNN_SUPPORT) endif (tensorflow_ENABLE_MKL_SUPPORT) if (tensorflow_ENABLE_GPU) diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 77242b34fd8302cb9104c50a83d4141607911e7f..84c679162c3ed8ffc9babcd3af583b26fb62c2d6 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -108,180 +108,177 @@ ops or APIs. Step-by-step Windows build ========================== -1. Install the prerequisites detailed above, and set up your environment. - - * The following commands assume that you are using the Windows Command - Prompt (`cmd.exe`). You will need to set up your environment to use the - appropriate toolchain, i.e. the 64-bit tools. (Some of the binary targets - we will build are too large for the 32-bit tools, and they will fail with - out-of-memory errors.) 
The typical command to do set up your - environment is: - - ``` - D:\temp> "C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\bin\amd64\vcvarsall.bat" - ``` - - * When building with GPU support after installing the CUDNN zip file from NVidia, append its - bin directory to your PATH environment variable. - In case TensorFlow fails to find the CUDA dll's during initialization, check your PATH environment variable. - It should contain the directory of the CUDA dlls and the directory of the CUDNN dll. - For example: - - ``` - D:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin - D:\local\cuda\bin - ``` - - * When building with MKL support after installing [MKL](https://software.intel.com/en-us/mkl) from INTEL, append its bin directories to your PATH environment variable. - - In case TensorFlow fails to find the MKL dll's during initialization, check your PATH environment variable. - It should contain the directory of the MKL dlls. For example: - - ``` - D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\mkl - D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\compiler - D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\tbb\vc_mt - ``` - - - * We assume that `cmake` and `git` are installed and in your `%PATH%`. If - for example `cmake` is not in your path and it is installed in - `C:\Program Files (x86)\CMake\bin\cmake.exe`, you can add this directory - to your `%PATH%` as follows: - - ``` - D:\temp> set PATH="%PATH%;C:\Program Files (x86)\CMake\bin\cmake.exe" - ``` - -2. Clone the TensorFlow repository and create a working directory for your - build: - - ``` - D:\temp> git clone https://github.com/tensorflow/tensorflow.git - D:\temp> cd tensorflow\tensorflow\contrib\cmake - D:\temp\tensorflow\tensorflow\contrib\cmake> mkdir build - D:\temp\tensorflow\tensorflow\contrib\cmake> cd build - D:\temp\tensorflow\tensorflow\contrib\cmake\build> - ``` - -3. 
Invoke CMake to create Visual Studio solution and project files. - - **N.B.** This assumes that `cmake.exe` is in your `%PATH%` environment - variable. The other paths are for illustrative purposes only, and may - be different on your platform. The `^` character is a line continuation - and must be the last character on each line. - - ``` - D:\...\build> cmake .. -A x64 -DCMAKE_BUILD_TYPE=Release ^ - More? -DSWIG_EXECUTABLE=C:/tools/swigwin-3.0.10/swig.exe ^ - More? -DPYTHON_EXECUTABLE=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/python.exe ^ - More? -DPYTHON_LIBRARIES=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/libs/python35.lib - ``` - To build with GPU support add "^" at the end of the last line above following with: - ``` - More? -Dtensorflow_ENABLE_GPU=ON ^ - More? -DCUDNN_HOME="D:\...\cudnn" - ``` - To build with MKL support add "^" at the end of the last line above following with: - - ``` - More? -Dtensorflow_ENABLE_MKL_SUPPORT=ON ^ - More? -DMKL_HOME="D:\...\compilers_and_libraries" - ``` - - To enable SIMD instructions with MSVC, as AVX and SSE, define it as follows: - - ``` - More? -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX - ``` - - Note that the `-DCMAKE_BUILD_TYPE=Release` flag must match the build - configuration that you choose when invoking `msbuild`. The known-good - values are `Release` and `RelWithDebInfo`. The `Debug` build type is - not currently supported, because it relies on a `Debug` library for - Python (`python35d.lib`) that is not distributed by default. - - There are various options that can be specified when generating the - solution and project files: - - * `-DCMAKE_BUILD_TYPE=(Release|RelWithDebInfo)`: Note that the - `CMAKE_BUILD_TYPE` option must match the build configuration that you - choose when invoking MSBuild in step 4. The known-good values are - `Release` and `RelWithDebInfo`. 
The `Debug` build type is not currently - supported, because it relies on a `Debug` library for Python - (`python35d.lib`) that is not distributed by default. - - * `-Dtensorflow_BUILD_ALL_KERNELS=(ON|OFF)`. Defaults to `ON`. You can - build a small subset of the kernels for a faster build by setting this - option to `OFF`. - - * `-Dtensorflow_BUILD_CC_EXAMPLE=(ON|OFF)`. Defaults to `ON`. Generate - project files for a simple C++ - [example training program](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/tutorials/example_trainer.cc). - - * `-Dtensorflow_BUILD_PYTHON_BINDINGS=(ON|OFF)`. Defaults to `ON`. Generate - project files for building a PIP package containing the TensorFlow runtime - and its Python bindings. - - * `-Dtensorflow_ENABLE_GRPC_SUPPORT=(ON|OFF)`. Defaults to `ON`. Include - gRPC support and the distributed client and server code in the TensorFlow - runtime. - - * `-Dtensorflow_ENABLE_SSL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include - SSL support (for making secure HTTP requests) in the TensorFlow runtime. - This support is incomplete, and will be used for Google Cloud Storage - support. - - * `-Dtensorflow_ENABLE_GPU=(ON|OFF)`. Defaults to `OFF`. Include - GPU support. If GPU is enabled you need to install the CUDA 8.0 Toolkit and CUDNN 5.1. - CMake will expect the location of CUDNN in -DCUDNN_HOME=path_you_unzipped_cudnn. - - * `-Dtensorflow_BUILD_CC_TESTS=(ON|OFF)`. Defaults to `OFF`. This builds cc unit tests. - There are many of them and building will take a few hours. - After cmake, build and execute the tests with - ``` - MSBuild /p:Configuration=RelWithDebInfo ALL_BUILD.vcxproj - ctest -C RelWithDebInfo - ``` - - * `-Dtensorflow_BUILD_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python kernel tests. - After building the python wheel, you need to install the new wheel before running the tests. 
- To execute the tests, use - ``` - ctest -C RelWithDebInfo - ``` - - * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python tests on - serveral major packages. This option is only valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`. - After building the python wheel, you need to install the new wheel before running the tests. - To execute the tests, use - ``` - ctest -C RelWithDebInfo - ``` - - * `-Dtensorflow_ENABLE_MKL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL support. If MKL is enabled you need to install the [Intel Math Kernal Library](https://software.intel.com/en-us/mkl). - CMake will expect the location of MKL in -MKL_HOME=path_you_install_mkl. - - * `-Dtensorflow_ENABLE_MKLDNN_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL DNN support. MKL DNN is [Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)](https://github.com/intel/mkl-dnn). You have to add `-Dtensorflow_ENABLE_MKL_SUPPORT=ON` before including MKL DNN support. - - -4. Invoke MSBuild to build TensorFlow. - - To build the C++ example program, which will be created as a `.exe` - executable in the subdirectory `.\Release`: - - ``` - D:\...\build> MSBuild /p:Configuration=Release tf_tutorials_example_trainer.vcxproj - D:\...\build> Release\tf_tutorials_example_trainer.exe - ``` - - To build the PIP package, which will be created as a `.whl` file in the - subdirectory `.\tf_python\dist`: - - ``` - D:\...\build> MSBuild /p:Configuration=Release tf_python_build_pip_package.vcxproj - ``` - +1. Install the prerequisites detailed above, and set up your environment. + + * When building with GPU support after installing the CUDNN zip file from + NVidia, append its bin directory to your PATH environment variable. In + case TensorFlow fails to find the CUDA dll's during initialization, + check your PATH environment variable. It should contain the directory of + the CUDA dlls and the directory of the CUDNN dll. 
For example: + + ``` + D:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin + D:\local\cuda\bin + ``` + + * When building with MKL support after installing + [MKL](https://software.intel.com/en-us/mkl) from INTEL, append its bin + directories to your PATH environment variable. + + In case TensorFlow fails to find the MKL dll's during initialization, + check your PATH environment variable. It should contain the directory of + the MKL dlls. For example: + + ``` + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\mkl + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\compiler + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\tbb\vc_mt + ``` + + * We assume that `cmake` and `git` are installed and in your `%PATH%`. If + for example `cmake` is not in your path and it is installed in + `C:\Program Files (x86)\CMake\bin\cmake.exe`, you can add this directory + to your `%PATH%` as follows: + + ``` + D:\temp> set PATH="%PATH%;C:\Program Files (x86)\CMake\bin\cmake.exe" + ``` + +2. Clone the TensorFlow repository and create a working directory for your + build: + + ``` + D:\temp> git clone https://github.com/tensorflow/tensorflow.git + D:\temp> cd tensorflow\tensorflow\contrib\cmake + D:\temp\tensorflow\tensorflow\contrib\cmake> mkdir build + D:\temp\tensorflow\tensorflow\contrib\cmake> cd build + D:\temp\tensorflow\tensorflow\contrib\cmake\build> + ``` + +3. Invoke CMake to create Visual Studio solution and project files. + + **N.B.** This assumes that `cmake.exe` is in your `%PATH%` environment + variable. The other paths are for illustrative purposes only, and may be + different on your platform. The `^` character is a line continuation and + must be the last character on each line. + + ``` + D:\...\build> cmake .. -A x64 -Thost=x64 -DCMAKE_BUILD_TYPE=Release ^ + More? -DSWIG_EXECUTABLE=C:/tools/swigwin-3.0.10/swig.exe ^ + More? 
-DPYTHON_EXECUTABLE=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/python.exe ^ + More? -DPYTHON_LIBRARIES=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/libs/python35.lib + ``` + + To build with GPU support add "^" at the end of the last line above + following with: `More? -Dtensorflow_ENABLE_GPU=ON ^ More? + -DCUDNN_HOME="D:\...\cudnn"` To build with MKL support add "^" at the end of + the last line above following with: + + ``` + More? -Dtensorflow_ENABLE_MKL_SUPPORT=ON ^ + More? -DMKL_HOME="D:\...\compilers_and_libraries" + ``` + + To enable SIMD instructions with MSVC, as AVX and SSE, define it as follows: + + ``` + More? -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX + ``` + + Note that the `-DCMAKE_BUILD_TYPE=Release` flag must match the build + configuration that you choose when invoking `msbuild`. The known-good values + are `Release` and `RelWithDebInfo`. The `Debug` build type is not currently + supported, because it relies on a `Debug` library for Python + (`python35d.lib`) that is not distributed by default. + + The `-Thost=x64` flag will ensure that the 64 bit compiler and linker is + used when building. Without this flag, MSBuild will use the 32 bit toolchain + which is prone to compile errors such as "compiler out of heap space". + + There are various options that can be specified when generating the solution + and project files: + + * `-DCMAKE_BUILD_TYPE=(Release|RelWithDebInfo)`: Note that the + `CMAKE_BUILD_TYPE` option must match the build configuration that you + choose when invoking MSBuild in step 4. The known-good values are + `Release` and `RelWithDebInfo`. The `Debug` build type is not currently + supported, because it relies on a `Debug` library for Python + (`python35d.lib`) that is not distributed by default. + + * `-Dtensorflow_BUILD_ALL_KERNELS=(ON|OFF)`. Defaults to `ON`. You can + build a small subset of the kernels for a faster build by setting this + option to `OFF`. + + * `-Dtensorflow_BUILD_CC_EXAMPLE=(ON|OFF)`. 
Defaults to `ON`. Generate + project files for a simple C++ + [example training program](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/tutorials/example_trainer.cc). + + * `-Dtensorflow_BUILD_PYTHON_BINDINGS=(ON|OFF)`. Defaults to `ON`. + Generate project files for building a PIP package containing the + TensorFlow runtime and its Python bindings. + + * `-Dtensorflow_ENABLE_GRPC_SUPPORT=(ON|OFF)`. Defaults to `ON`. Include + gRPC support and the distributed client and server code in the + TensorFlow runtime. + + * `-Dtensorflow_ENABLE_SSL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include + SSL support (for making secure HTTP requests) in the TensorFlow runtime. + This support is incomplete, and will be used for Google Cloud Storage + support. + + * `-Dtensorflow_ENABLE_GPU=(ON|OFF)`. Defaults to `OFF`. Include GPU + support. If GPU is enabled you need to install the CUDA 8.0 Toolkit and + CUDNN 5.1. CMake will expect the location of CUDNN in + -DCUDNN_HOME=path_you_unzipped_cudnn. + + * `-Dtensorflow_BUILD_CC_TESTS=(ON|OFF)`. Defaults to `OFF`. This builds + cc unit tests. There are many of them and building will take a few + hours. After cmake, build and execute the tests with `MSBuild + /p:Configuration=RelWithDebInfo ALL_BUILD.vcxproj ctest -C + RelWithDebInfo` + + * `-Dtensorflow_BUILD_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This + enables python kernel tests. After building the python wheel, you need + to install the new wheel before running the tests. To execute the tests, + use `ctest -C RelWithDebInfo` + + * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This + enables python tests on several major packages. This option is only + valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`. + After building the python wheel, you need to install the new wheel + before running the tests. To execute the tests, use `ctest -C + RelWithDebInfo` + + * `-Dtensorflow_ENABLE_MKL_SUPPORT=(ON|OFF)`. Defaults to `OFF`.
Include + MKL support. If MKL is enabled you need to install the + [Intel Math Kernel Library](https://software.intel.com/en-us/mkl). CMake + will expect the location of MKL in -DMKL_HOME=path_you_install_mkl. + + * `-Dtensorflow_ENABLE_MKLDNN_SUPPORT=(ON|OFF)`. Defaults to `OFF`. + Include MKL DNN support. MKL DNN is [Intel(R) Math Kernel Library for + Deep Neural Networks (Intel(R) + MKL-DNN)](https://github.com/intel/mkl-dnn). You have to add + `-Dtensorflow_ENABLE_MKL_SUPPORT=ON` before including MKL DNN support. + +4. Invoke MSBuild to build TensorFlow. + + Set up the path to find MSBuild: `D:\temp> "C:\Program Files (x86)\Microsoft + Visual Studio 14.0\VC\bin\amd64\vcvarsall.bat"` + + To build the C++ example program, which will be created as a `.exe` + executable in the subdirectory `.\Release`: + + ``` + D:\...\build> MSBuild /p:Configuration=Release tf_tutorials_example_trainer.vcxproj + D:\...\build> Release\tf_tutorials_example_trainer.exe + ``` + + To build the PIP package, which will be created as a `.whl` file in the + subdirectory `.\tf_python\dist`: + + ``` + D:\...\build> MSBuild /p:Configuration=Release tf_python_build_pip_package.vcxproj + ``` Linux Continuous Integration build ================================== diff --git a/tensorflow/contrib/cmake/external/abseil_cpp.cmake b/tensorflow/contrib/cmake/external/abseil_cpp.cmake new file mode 100644 index 0000000000000000000000000000000000000000..c6c5021f60b38ed05a19f3e439c9810251841f76 --- /dev/null +++ b/tensorflow/contrib/cmake/external/abseil_cpp.cmake @@ -0,0 +1,100 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +if (systemlib_ABSEIL_CPP) + + find_package(AbseilCpp REQUIRED + absl_base + absl_spinlock_wait + absl_dynamic_annotations + absl_malloc_internal + absl_throw_delegate + absl_strings + str_format_internal + absl_bad_optional_access) + + include_directories(${ABSEIL_CPP_INCLUDE_DIR}) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${ABSEIL_CPP_LIBRARIES}) + + message(STATUS " abseil_cpp includes: ${ABSEIL_CPP_INCLUDE_DIR}") + message(STATUS " abseil_cpp libraries: ${ABSEIL_CPP_LIBRARIES}") + + add_custom_target(abseil_cpp_build) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) + +else (systemlib_ABSEIL_CPP) + + include (ExternalProject) + + set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) + set(abseil_cpp_URL https://github.com/abseil/abseil-cpp/archive/e01d95528ea2137a4a27a88d1f57c6cb260aafed.tar.gz) + set(abseil_cpp_HASH SHA256=84043ed402d2a2a6ba4cdddb7e85118b1158fd81fe4ac3a14adc343d054c1e2e) + set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) + + if(WIN32) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(abseil_cpp_STATIC_LIBRARIES + ${abseil_cpp_BUILD}/absl/base/Release/absl_base.lib + ${abseil_cpp_BUILD}/absl/base/Release/absl_spinlock_wait.lib + ${abseil_cpp_BUILD}/absl/base/Release/absl_dynamic_annotations.lib + ${abseil_cpp_BUILD}/absl/base/Release/absl_malloc_internal.lib + ${abseil_cpp_BUILD}/absl/base/Release/absl_throw_delegate.lib + ${abseil_cpp_BUILD}/absl/strings/Release/absl_strings.lib + 
${abseil_cpp_BUILD}/absl/strings/Release/str_format_internal.lib + ${abseil_cpp_BUILD}/absl/types/Release/absl_bad_optional_access.lib) + else() + set(abseil_cpp_STATIC_LIBRARIES + ${abseil_cpp_BUILD}/absl/base/absl_base.lib + ${abseil_cpp_BUILD}/absl/base/absl_spinlock_wait.lib + ${abseil_cpp_BUILD}/absl/base/absl_dynamic_annotations.lib + ${abseil_cpp_BUILD}/absl/base/absl_malloc_internal.lib + ${abseil_cpp_BUILD}/absl/base/absl_throw_delegate.lib + ${abseil_cpp_BUILD}/absl/strings/absl_strings.lib + ${abseil_cpp_BUILD}/absl/strings/str_format_internal.lib + ${abseil_cpp_BUILD}/absl/types/absl_bad_optional_access.lib) + endif() + else() + set(abseil_cpp_STATIC_LIBRARIES + ${abseil_cpp_BUILD}/absl/base/libabsl_base.a + ${abseil_cpp_BUILD}/absl/base/libabsl_spinlock_wait.a + ${abseil_cpp_BUILD}/absl/base/libabsl_dynamic_annotations.a + ${abseil_cpp_BUILD}/absl/base/libabsl_malloc_internal.a + ${abseil_cpp_BUILD}/absl/base/libabsl_throw_delegate.a + ${abseil_cpp_BUILD}/absl/strings/libabsl_strings.a + ${abseil_cpp_BUILD}/absl/strings/libstr_format_internal.a + ${abseil_cpp_BUILD}/absl/types/libabsl_bad_optional_access.a) + endif() + + ExternalProject_Add(abseil_cpp_build + PREFIX abseil_cpp + URL ${abseil_cpp_URL} + URL_HASH ${abseil_cpp_HASH} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${abseil_cpp_STATIC_LIBRARIES} + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release + COMMAND ${CMAKE_COMMAND} --build . 
--config Release + INSTALL_COMMAND "" + CMAKE_CACHE_ARGS + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} + -DCMAKE_BUILD_TYPE:STRING=Release + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + ) + + include_directories(${abseil_cpp_INCLUDE_DIR}) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${abseil_cpp_STATIC_LIBRARIES}) + + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) + +endif (systemlib_ABSEIL_CPP) diff --git a/tensorflow/contrib/cmake/external/jemalloc.cmake b/tensorflow/contrib/cmake/external/jemalloc.cmake deleted file mode 100644 index afadcc007d66414be3306e91e7186a00b6e587ce..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/cmake/external/jemalloc.cmake +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -include (ExternalProject) - -set(jemalloc_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc/include) -set(jemalloc_URL https://mirror.bazel.build/github.com/jemalloc/jemalloc-cmake/archive/jemalloc-cmake.4.3.1.tar.gz) -set(jemalloc_HASH SHA256=f9be9a05fe906deb5c1c8ca818071a7d2e27d66fd87f5ba9a7bf3750bcedeaf0) -set(jemalloc_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc) - -if (WIN32) - set(jemalloc_INCLUDE_DIRS - ${jemalloc_INCLUDE_DIRS} - ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc/include/msvc_compat - ) - if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") - set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/Release/jemalloc.lib) - else() - set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/jemalloc.lib) - endif() -else() - set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/Release/jemalloc.a) -endif() - -ExternalProject_Add(jemalloc - PREFIX jemalloc - URL ${jemalloc_URL} - URL_HASH ${jemalloc_HASH} - DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" - BUILD_IN_SOURCE 1 - BUILD_BYPRODUCTS ${jemalloc_STATIC_LIBRARIES} - BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target jemalloc - INSTALL_COMMAND ${CMAKE_COMMAND} -E echo "Skipping install step." 
- CMAKE_CACHE_ARGS - -DCMAKE_BUILD_TYPE:STRING=Release - -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -Dwith-jemalloc-prefix:STRING=jemalloc_ - -Dwithout-export:BOOL=ON -) diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index f56fb35a0f71250f00b84e5cf94a24682bda6c82..56a57a2340ddc7f923c611c222a0399e279ad58a 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG v3.6.0) +set(PROTOBUF_TAG v3.6.1) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/make.bat b/tensorflow/contrib/cmake/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..d52b24e01d6590180106ba6cee2782c99d734ee3 --- /dev/null +++ b/tensorflow/contrib/cmake/make.bat @@ -0,0 +1,38 @@ +%echo off + +cd /d %~dp0 + +if exist _build rd /s /q _build + +mkdir _build +chdir _build + + +rem cmake ../ -G "Visual Studio 15 Win64" -DCMAKE_GENERATOR_TOOLSET=v141,host=x64 -DCMAKE_INSTALL_PREFIX:PATH=.\install + +CALL :NORMALIZEPATH "..\..\..\.." 
+SET SOURCE_DIR=%RETVAL% + +echo %SOURCE_DIR% + +SET SOURCE_DIR=F:\frameworks\tensorflow\ + +CALL :NORMALIZEPATH "../../../tools/git/gen_git_source.py" +SET SOURCE_PYTHON_SCRIPT=%RETVAL% + +CALL :NORMALIZEPATH "../../../core/util/version_info.cc" +SET SOURCE_VERSION_CC=%RETVAL% + +python %SOURCE_PYTHON_SCRIPT% --raw_generate %SOURCE_VERSION_CC% --source_dir %SOURCE_DIR% --git_tag_override= + +cmake ../ -G "Visual Studio 15 Win64" -DCMAKE_GENERATOR_TOOLSET=v141,host=x64 -DCMAKE_INSTALL_PREFIX:PATH=.\install + +EXIT /B + +:NORMALIZEPATH + SET RETVAL=%~dpfn1 + EXIT /B + + + + \ No newline at end of file diff --git a/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake b/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake new file mode 100644 index 0000000000000000000000000000000000000000..d4f8bb1bec9ae8eff58dfe78168d8e71319c85e1 --- /dev/null +++ b/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake @@ -0,0 +1,72 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +find_path(ABSEIL_CPP_INCLUDE_DIR absl/base/config.h + HINTS "${ABSEIL_CPP_INCLUDE_DIR_HINTS}" + PATHS "$ENV{PROGRAMFILES}" + "$ENV{PROGRAMW6432}" + PATH_SUFFIXES "") + +if(EXISTS "${ABSEIL_CPP_INCLUDE_DIR}" AND NOT "${ABSEIL_CPP_INCLUDE_DIR}" STREQUAL "") + + if(NOT AbseilCpp_FIND_COMPONENTS) + # search all libraries if no COMPONENTS was requested + set(AbseilCpp_FIND_COMPONENTS + "absl_algorithm;absl_any;absl_bad_any_cast" + "absl_bad_optional_access;absl_base absl_container;absl_debugging" + "absl_dynamic_annotations;absl_examine_stack;absl_failure_signal_handler" + "absl_int128;absl_leak_check;absl_malloc_internal;absl_memory;absl_meta" + "absl_numeric;absl_optional;absl_span;absl_spinlock_wait;absl_stack_consumption" + "absl_stacktrace;absl_str_format;absl_strings;absl_symbolize;absl_synchronization" + "absl_throw_delegate;absl_time;absl_utility;str_format_extension_internal" + "str_format_internal;test_instance_tracker_lib") + endif() + + foreach(LIBNAME ${AbseilCpp_FIND_COMPONENTS}) + + unset(ABSEIL_CPP_LIBRARY CACHE) + + find_library(ABSEIL_CPP_LIBRARY + NAMES ${LIBNAME} + HINTS ${ABSEIL_CPP_LIBRARIES_DIR_HINTS}) + + if(ABSEIL_CPP_LIBRARY) + list(APPEND ABSEIL_CPP_LIBRARIES ${ABSEIL_CPP_LIBRARY}) + else() + message(FATAL_ERROR "\n" + "abseil_cpp library \"${LIBNAME}\" not found in system path.\n" + "Please provide locations using: -DABSEIL_CPP_LIBRARIES_DIR_HINTS:STRING=\"PATH\"\n") + endif() + + endforeach() + + unset(LIBNAME CACHE) + unset(ABSEIL_CPP_LIBRARY CACHE) + + set(ABSEIL_CPP_FOUND TRUE) + message(STATUS "Found abseil_cpp libraries") + + set(ABSEIL_CPP_INCLUDE_DIR "${ABSEIL_CPP_INCLUDE_DIR}" CACHE PATH "" FORCE) + mark_as_advanced(ABSEIL_CPP_INCLUDE_DIR) + + set(ABSEIL_CPP_LIBRARIES "${ABSEIL_CPP_LIBRARIES}" CACHE PATH "" FORCE) + mark_as_advanced(ABSEIL_CPP_LIBRARIES) + +else() + + message(FATAL_ERROR "\n" + "abseil_cpp headers not found in system path.\n" + "Please 
provide locations using: -DABSEIL_CPP_INCLUDE_DIR_HINTS:STRING=\"PATH\"\n") + +endif() diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 2975b167ec84eee8043b06af5e972330d319d338..6e72670142d560a364350bb4769f1153f884b0f6 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -134,7 +134,6 @@ tensorflow/contrib/cudnn_rnn/python/ops tensorflow/contrib/data tensorflow/contrib/data/python tensorflow/contrib/data/python/kernel_tests -tensorflow/contrib/data/python/kernel_tests/serialization tensorflow/contrib/data/python/ops tensorflow/contrib/decision_trees tensorflow/contrib/decision_trees/proto @@ -206,6 +205,8 @@ tensorflow/contrib/integrate/python tensorflow/contrib/integrate/python/ops tensorflow/contrib/kafka/python tensorflow/contrib/kafka/python/ops +tensorflow/contrib/ignite/python +tensorflow/contrib/ignite/python/ops tensorflow/contrib/keras tensorflow/contrib/keras/api tensorflow/contrib/keras/api/keras diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 067c299a71cd4ac96878bcf27b4453466785e4ba..7e806685b8448cbd629985cdc00ed1193857abe6 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -258,14 +258,21 @@ add_dependencies(tf_core_lib ${tensorflow_EXTERNAL_DEPENDENCIES} tf_protos_cc) # force_rebuild always runs forcing ${VERSION_INFO_CC} target to run # ${VERSION_INFO_CC} would cache, but it depends on a phony never produced # target. 
-set(VERSION_INFO_CC ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc) -add_custom_target(force_rebuild_target ALL DEPENDS ${VERSION_INFO_CC}) -add_custom_command(OUTPUT __force_rebuild COMMAND ${CMAKE_COMMAND} -E echo) -add_custom_command(OUTPUT - ${VERSION_INFO_CC} - COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/git/gen_git_source.py - ARGS --raw_generate ${VERSION_INFO_CC} --source_dir ${tensorflow_source_dir} --git_tag_override=${GIT_TAG_OVERRIDE} - DEPENDS __force_rebuild) +# This code forces rebuild every time, not needed as version from git is fetched only once +# move to make.bat which mimicks make.sh + +if (NOT WIN32) + + set(VERSION_INFO_CC ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc) + add_custom_target(force_rebuild_target ALL DEPENDS ${VERSION_INFO_CC}) + add_custom_command(OUTPUT __force_rebuild COMMAND ${CMAKE_COMMAND} -E echo) + add_custom_command(OUTPUT + ${VERSION_INFO_CC} + COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/git/gen_git_source.py + ARGS --raw_generate ${VERSION_INFO_CC} --source_dir ${tensorflow_source_dir} --git_tag_override=${GIT_TAG_OVERRIDE} + DEPENDS __force_rebuild) +endif() + set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc) ######################################################## diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD index 4bfd753bb1d1fc254c66a4f7eb1d6ac83a40cb70..7f96a103d4cd797bc733a41a673eac492419b4c6 100644 --- a/tensorflow/contrib/coder/BUILD +++ b/tensorflow/contrib/coder/BUILD @@ -13,12 +13,12 @@ load( "//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_custom_op_library", - "tf_custom_op_py_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", "tf_kernel_library", "tf_py_test", ) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") cc_library( name = "range_coder", diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index 
873b03580d6f1d9cb25c79cb31989d43cdb8c9a7..f2636e190c25c094dd4ee1370c4728994b1014f5 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -272,7 +272,7 @@ def _compile_internal(computation, inputs=None): raise TypeError( 'Supplied computation cannot be called with the specified inputs. You ' 'specified %d inputs: %s, but the computation needs %s' % - (input_arity, str([i.name for i in inputs[0]]), arg_error)) + (input_arity, str([i.name for i in inputs]), arg_error)) cluster_name = ops.get_default_graph().unique_name('cluster') pivot = control_flow_ops.no_op(name=cluster_name + '/pivot') @@ -606,8 +606,8 @@ class _ModelFnWrapper(object): def estimator_model_fn(target_model_fn=None): """estimator_model_fn decorates a model_fn to be compiled for execution. - Currently only it only works with `TPUEstimator`. If you need to use it with - base `Estimator`, please add `tf.enable_resource_variables()` at beginning of + Currently it only works with `TPUEstimator`. If you need to use it with base + `Estimator`, please add `tf.enable_resource_variables()` at the beginning of your program. 
Example 1, decorating model_fn: diff --git a/tensorflow/contrib/copy_graph/python/util/copy_test.py b/tensorflow/contrib/copy_graph/python/util/copy_test.py index ba97c7845635596c3f4f849044b6707ec43f5bbf..4d8651a79fde9b876d4fdd9b050e71d2eb7c893d 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_test.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_test.py @@ -26,15 +26,16 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -graph1 = ops.Graph() -graph2 = ops.Graph() - class CopyVariablesTest(test.TestCase): + def setUp(self): + self.graph1 = ops.Graph() + self.graph2 = ops.Graph() + def testVariableCopy(self): - with graph1.as_default(): + with self.graph1.as_default(): #Define a Variable in graph1 some_var = variables.VariableV1(2) #Initialize session @@ -43,13 +44,15 @@ class CopyVariablesTest(test.TestCase): variables.global_variables_initializer().run(session=sess1) #Make a copy of some_var in the defsult scope in graph2 - copy1 = copy_elements.copy_variable_to_graph(some_var, graph2) + copy1 = copy_elements.copy_variable_to_graph(some_var, self.graph2) #Make another copy with different scope - copy2 = copy_elements.copy_variable_to_graph(some_var, graph2, "test_scope") + copy2 = copy_elements.copy_variable_to_graph(some_var, + self.graph2, + "test_scope") #Initialize both the copies - with graph2.as_default(): + with self.graph2.as_default(): #Initialize Session sess2 = session_lib.Session() #Initialize the Variables @@ -67,9 +70,13 @@ class CopyVariablesTest(test.TestCase): class CopyOpsTest(test.TestCase): + def setUp(self): + self.graph1 = ops.Graph() + self.graph2 = ops.Graph() + def testOpsCopy(self): - with graph1.as_default(): + with self.graph1.as_default(): #Initialize a basic expression y = ax + b x = array_ops.placeholder("float") a = variables.VariableV1(3.0) @@ -82,21 +89,21 @@ class CopyOpsTest(test.TestCase): 
variables.global_variables_initializer().run(session=sess1) #First, initialize a as a Variable in graph2 - a1 = copy_elements.copy_variable_to_graph(a, graph2) + a1 = copy_elements.copy_variable_to_graph(a, self.graph2) #Initialize a1 in graph2 - with graph2.as_default(): + with self.graph2.as_default(): #Initialize session sess2 = session_lib.Session() #Initialize the Variable variables.global_variables_initializer().run(session=sess2) #Initialize a copy of y in graph2 - y1 = copy_elements.copy_op_to_graph(y, graph2, [a1]) + y1 = copy_elements.copy_op_to_graph(y, self.graph2, [a1]) #Now that y has been copied, x must be copied too. #Get that instance - x1 = copy_elements.get_copied_op(x, graph2) + x1 = copy_elements.get_copied_op(x, self.graph2) #Compare values of y & y1 for a sample input #and check if they match diff --git a/tensorflow/contrib/data/README.md b/tensorflow/contrib/data/README.md index 848782e8d89b8670caf3b45de4912a7e0855c102..90be7a66cac6746e29a121fe6a772a94e04e8f02 100644 --- a/tensorflow/contrib/data/README.md +++ b/tensorflow/contrib/data/README.md @@ -1,10 +1,12 @@ `tf.contrib.data` API ===================== -NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead. -We are continuing to support existing code using the `tf.contrib.data` APIs in -the current version of TensorFlow, but will eventually remove support. The -`tf.data` APIs are subject to backwards compatibility guarantees. +NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead, +or `tf.data.experimental` for the experimental transformations previously hosted +in this module. We are continuing to support existing code using the +`tf.contrib.data` APIs in the current version of TensorFlow, but will eventually +remove support. The non-experimental `tf.data` APIs are subject to backwards +compatibility guarantees. 
Porting your code to `tf.data` ------------------------------ @@ -25,13 +27,13 @@ instead apply them using `Dataset.apply()` transformation. The full list of changes is as follows: * `dataset.dense_to_sparse_batch(...)` is now - `dataset.apply(tf.contrib.data.dense_to_sparse_batch(...)`. + `dataset.apply(tf.data.experimental.dense_to_sparse_batch(...)`. * `dataset.enumerate(...)` is now - `dataset.apply(tf.contrib.data.enumerate_dataset(...))`. + `dataset.apply(tf.data.experimental.enumerate_dataset(...))`. * `dataset.group_by_window(...)` is now - `dataset.apply(tf.contrib.data.group_by_window(...))`. + `dataset.apply(tf.data.experimental.group_by_window(...))`. * `dataset.ignore_errors()` is now - `dataset.apply(tf.contrib.data.ignore_errors())`. + `dataset.apply(tf.data.experimental.ignore_errors())`. * `dataset.unbatch()` is now `dataset.apply(tf.contrib.data.unbatch())`. The `Dataset.make_dataset_resource()` and `Iterator.dispose_op()` methods have diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 3cb51279c356f5fe79da98acb5cf481b4d76f6b8..c3d3e981fa10144ed94217cf948b485a7c2bced8 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -96,10 +96,6 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator - -# Optimization constant that can be used to enable auto-tuning. 
-from tensorflow.contrib.data.python.ops.optimization import AUTOTUNE - from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device @@ -114,11 +110,12 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.scan_ops import scan from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch -from tensorflow.contrib.data.python.ops.stats_ops import latency_stats -from tensorflow.contrib.data.python.ops.stats_ops import set_stats_aggregator -from tensorflow.contrib.data.python.ops.stats_ops import StatsAggregator from tensorflow.contrib.data.python.ops.unique import unique from tensorflow.contrib.data.python.ops.writers import TFRecordWriter + +# Optimization constant that can be used to enable auto-tuning. 
+from tensorflow.python.data.experimental.ops.optimization import AUTOTUNE + from tensorflow.python.data.ops.iterator_ops import get_next_as_optional from tensorflow.python.data.ops.optional_ops import Optional # pylint: enable=unused-import diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 33784afa3fdd739c3572278bff294b270feabd00..42f538b4ba1cb5b6a9a717e94f4e688cae56b056 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -8,51 +8,17 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "py_test") py_test( - name = "batch_dataset_op_test", - size = "medium", - srcs = ["batch_dataset_op_test.py"], + name = "assert_element_shape_test", + srcs = ["assert_element_shape_test.py"], srcs_version = "PY2AND3", - tags = [ - "no_oss", # (b/79552534) - "no_pip", - ], deps = [ "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", "//tensorflow/python:script_ops", - "//tensorflow/python:session", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:util", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - -py_test( - name = "bucketing_test", - size = "medium", - srcs = ["bucketing_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:grouping", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - 
"//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python/data/kernel_tests:test_base", @@ -61,147 +27,6 @@ py_test( ], ) -py_test( - name = "csv_dataset_op_test", - size = "medium", - srcs = ["csv_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - "//tensorflow/contrib/data/python/ops:error_ops", - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:platform_test", - "//tensorflow/python:session", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:readers", - "//tensorflow/python/eager:context", - "//third_party/py/numpy", - ], -) - -py_test( - name = "dataset_constructor_op_test", - size = "medium", - srcs = ["dataset_constructor_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "manual", - "nomac", # b/62040583 - ], - deps = [ - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - ], -) - -py_test( - name = "directed_interleave_dataset_test", - size = "medium", - srcs = ["directed_interleave_dataset_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:random_seed", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], 
-) - -py_test( - name = "get_single_element_test", - size = "small", - srcs = ["get_single_element_test.py"], - deps = [ - "//tensorflow/contrib/data/python/ops:get_single_element", - "//tensorflow/contrib/data/python/ops:grouping", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "@absl_py//absl/testing:parameterized", - ], -) - -py_test( - name = "indexed_dataset_ops_test", - srcs = ["indexed_dataset_ops_test.py"], - deps = [ - "//tensorflow/contrib/data/python/ops:indexed_dataset_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:experimental_dataset_ops_gen", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "interleave_dataset_op_test", - size = "medium", - srcs = ["interleave_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_oss", - "no_pip", - "notap", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:math_ops", - "//tensorflow/python:script_ops", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "@six_archive//:six", - ], -) - -py_test( - name = "iterator_ops_test", - size = "small", - srcs = ["iterator_ops_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - "//tensorflow/contrib/data/python/ops:iterator_ops", - "//tensorflow/python:client_testlib", - 
"//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/estimator:estimator_py", - ], -) - py_test( name = "lmdb_dataset_op_test", size = "medium", @@ -229,252 +54,18 @@ py_test( ) py_test( - name = "map_dataset_op_test", - size = "medium", - srcs = ["map_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "noasan", # times out - "optonly", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/contrib/data/python/ops:error_ops", - "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:util", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "filter_dataset_op_test", - size = "medium", - srcs = ["filter_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:util", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "map_defun_op_test", + name = "reduce_dataset_test", size = "small", - srcs = ["map_defun_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], + srcs = ["reduce_dataset_test.py"], deps = [ - 
"//tensorflow/contrib/data/python/ops:map_defun", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:session", - "//tensorflow/python/data/kernel_tests:test_base", - ], -) - -py_test( - name = "parsing_ops_test", - size = "small", - srcs = ["parsing_ops_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:parsing_ops", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//third_party/py/numpy", - ], -) - -cuda_py_test( - name = "prefetching_ops_test", - size = "small", - srcs = ["prefetching_ops_test.py"], - additional_deps = [ - "//tensorflow/contrib/data/python/ops:prefetching_ops", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/compat:compat", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - ], - tags = ["no_windows_gpu"], -) - -py_test( - name = "range_dataset_op_test", - size = "small", - srcs = ["range_dataset_op_test.py"], - srcs_version = 
"PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:counter", - "//tensorflow/contrib/data/python/ops:enumerate_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - ], -) - -py_library( - name = "reader_dataset_ops_test_base", - testonly = 1, - srcs = [ - "reader_dataset_ops_test_base.py", - ], - srcs_version = "PY2AND3", - visibility = [ - "//tensorflow/contrib/data/python/kernel_tests:__pkg__", - "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/data/python/ops:get_single_element", + "//tensorflow/contrib/data/python/ops:grouping", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", - "//tensorflow/python:lib", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:util", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/ops:readers", - ], -) - -py_test( - name = "reader_dataset_ops_test", - size = "medium", - srcs = ["reader_dataset_ops_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":reader_dataset_ops_test_base", - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:string_ops", - "//tensorflow/python/data/ops:readers", - "//tensorflow/python/data/util:nest", - "//third_party/py/numpy", - ], -) - -py_test( - name = "resample_test", - size = "medium", - 
srcs = ["resample_test.py"], - shard_count = 2, - srcs_version = "PY2AND3", - tags = [ - "noasan", - "optonly", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:resampling", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:string_ops", - "//tensorflow/python:util", "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", - "@six_archive//:six", - ], -) - -py_test( - name = "scan_dataset_op_test", - size = "small", - srcs = ["scan_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - "//tensorflow/contrib/data/python/ops:scan_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/eager:context", - "//third_party/py/numpy", - ], -) - -py_test( - name = "shuffle_dataset_op_test", - size = "medium", - srcs = ["shuffle_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "optonly", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:shuffle_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", ], ) @@ -496,142 +87,3 @@ py_test( "@absl_py//absl/testing:parameterized", ], ) - -py_library( - name = "sql_dataset_op_test_base", - srcs = ["sql_dataset_op_test_base.py"], - srcs_version = "PY2AND3", - visibility = [ - 
"//tensorflow/contrib/data/python/kernel_tests:__pkg__", - "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python/data/kernel_tests:test_base", - "@org_sqlite//:python", - ], -) - -py_test( - name = "sql_dataset_op_test", - size = "small", - srcs = ["sql_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":sql_dataset_op_test_base", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - ], -) - -py_test( - name = "stats_dataset_ops_test", - size = "medium", - srcs = ["stats_dataset_ops_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":reader_dataset_ops_test_base", - ":stats_dataset_test_base", - "//tensorflow/contrib/data/python/ops:stats_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -py_library( - name = "stats_dataset_test_base", - srcs = ["stats_dataset_test_base.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python/data/kernel_tests:test_base", - ], -) - -py_test( - name = "threadpool_dataset_ops_test", - size = "small", - srcs = ["threadpool_dataset_ops_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - "//tensorflow/contrib/data/python/ops:threadpool", - "//tensorflow/contrib/data/python/ops:unique", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:script_ops", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - 
"//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - -py_test( - name = "unique_dataset_op_test", - size = "small", - srcs = ["unique_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - "//tensorflow/contrib/data/python/ops:unique", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:util", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - ], -) - -py_test( - name = "window_dataset_op_test", - size = "medium", - srcs = ["window_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/contrib/data/python/ops:grouping", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:math_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - -py_test( - name = "writer_ops_test", - size = "small", - srcs = ["writer_ops_test.py"], - deps = [ - "//tensorflow/contrib/data/python/ops:writers", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:lib", - "//tensorflow/python:util", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:readers", - ], -) diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0456463a1928cf226010670b90a5d574579e0411 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py @@ 
-0,0 +1,226 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import script_ops +from tensorflow.python.platform import test + + +class AssertElementShapeTest(test_base.DatasetTestBase): + + def test_assert_element_shape(self): + + def create_dataset(_): + return (array_ops.ones(2, dtype=dtypes.float32), + array_ops.zeros((3, 4), dtype=dtypes.int32)) + + dataset = dataset_ops.Dataset.range(5).map(create_dataset) + expected_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 4))) + self.assertEqual(expected_shapes, dataset.output_shapes) + + result = dataset.apply(batching.assert_element_shape(expected_shapes)) + self.assertEqual(expected_shapes, result.output_shapes) + + iterator = result.make_initializable_iterator() + init_op = 
iterator.initializer + get_next = iterator.get_next() + with self.cached_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_assert_wrong_element_shape(self): + + def create_dataset(_): + return (array_ops.ones(2, dtype=dtypes.float32), + array_ops.zeros((3, 4), dtype=dtypes.int32)) + + dataset = dataset_ops.Dataset.range(3).map(create_dataset) + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 10))) + with self.assertRaises(ValueError): + dataset.apply(batching.assert_element_shape(wrong_shapes)) + + def test_assert_element_shape_on_unknown_shape_dataset(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func( + lambda _: ( # pylint: disable=g-long-lambda + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + expected_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 4))) + result = dataset.apply(batching.assert_element_shape(expected_shapes)) + self.assertEqual(expected_shapes, result.output_shapes) + + iterator = result.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.cached_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_assert_wrong_element_shape_on_unknown_shape_dataset(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func( + lambda _: ( # pylint: disable=g-long-lambda + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = 
dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 10))) + iterator = ( + dataset.apply(batching.assert_element_shape(wrong_shapes)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + with self.cached_session() as sess: + sess.run(init_op) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + def test_assert_partial_element_shape(self): + + def create_dataset(_): + return (array_ops.ones(2, dtype=dtypes.float32), + array_ops.zeros((3, 4), dtype=dtypes.int32)) + + dataset = dataset_ops.Dataset.range(5).map(create_dataset) + partial_expected_shape = ( + tensor_shape.TensorShape(None), # Unknown shape + tensor_shape.TensorShape((None, 4))) # Partial shape + result = dataset.apply( + batching.assert_element_shape(partial_expected_shape)) + # Partial shapes are merged with actual shapes: + actual_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 4))) + self.assertEqual(actual_shapes, result.output_shapes) + + iterator = result.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.cached_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_assert_wrong_partial_element_shape(self): + + def create_dataset(_): + return (array_ops.ones(2, dtype=dtypes.float32), + array_ops.zeros((3, 4), dtype=dtypes.int32)) + + dataset = dataset_ops.Dataset.range(3).map(create_dataset) + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((None, 10))) + with self.assertRaises(ValueError): + dataset.apply(batching.assert_element_shape(wrong_shapes)) + + def 
test_assert_partial_element_shape_on_unknown_shape_dataset(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func( + lambda _: ( # pylint: disable=g-long-lambda + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + expected_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((None, 4))) + result = dataset.apply(batching.assert_element_shape(expected_shapes)) + self.assertEqual(expected_shapes, result.output_shapes) + + iterator = result.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.cached_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func( + lambda _: ( # pylint: disable=g-long-lambda + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((None, 10))) + iterator = ( + dataset.apply(batching.assert_element_shape(wrong_shapes)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + with self.cached_session() as sess: + sess.run(init_op) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() 
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py deleted file mode 100644 index ae401f786cffde41d24097a09f6abcb1a43833e8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ /dev/null @@ -1,824 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import random - -import numpy as np - -from tensorflow.contrib.data.python.ops import grouping -from tensorflow.python.data.kernel_tests import test_base -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import string_ops -from tensorflow.python.platform import test - - -class GroupByReducerTest(test_base.DatasetTestBase): - - def checkResults(self, 
dataset, shapes, values): - self.assertEqual(shapes, dataset.output_shapes) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - for expected in values: - got = sess.run(get_next) - self.assertEqual(got, expected) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSum(self): - reducer = grouping.Reducer( - init_func=lambda _: np.int64(0), - reduce_func=lambda x, y: x + y, - finalize_func=lambda x: x) - for i in range(1, 11): - dataset = dataset_ops.Dataset.range(2 * i).apply( - grouping.group_by_reducer(lambda x: x % 2, reducer)) - self.checkResults( - dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i]) - - def testAverage(self): - - def reduce_fn(x, y): - return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / ( - x[1] + 1), x[1] + 1 - - reducer = grouping.Reducer( - init_func=lambda _: (0.0, 0.0), - reduce_func=reduce_fn, - finalize_func=lambda x, _: x) - for i in range(1, 11): - dataset = dataset_ops.Dataset.range(2 * i).apply( - grouping.group_by_reducer( - lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer)) - self.checkResults( - dataset, shapes=tensor_shape.scalar(), values=[i - 1, i]) - - def testConcat(self): - components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray) - reducer = grouping.Reducer( - init_func=lambda x: "", - reduce_func=lambda x, y: x + y[0], - finalize_func=lambda x: x) - for i in range(1, 11): - dataset = dataset_ops.Dataset.zip( - (dataset_ops.Dataset.from_tensor_slices(components), - dataset_ops.Dataset.range(2 * i))).apply( - grouping.group_by_reducer(lambda x, y: y % 2, reducer)) - self.checkResults( - dataset, - shapes=tensor_shape.scalar(), - values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]]) - - def testSparseSum(self): - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0]]), - values=(i * np.array([1], dtype=np.int64)), - dense_shape=np.array([1, 1])) - - reducer = grouping.Reducer( - 
init_func=lambda _: _sparse(np.int64(0)), - reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]), - finalize_func=lambda x: x.values[0]) - for i in range(1, 11): - dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply( - grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer)) - self.checkResults( - dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i]) - - def testChangingStateShape(self): - - def reduce_fn(x, _): - # Statically known rank, but dynamic length. - larger_dim = array_ops.concat([x[0], x[0]], 0) - # Statically unknown rank. - larger_rank = array_ops.expand_dims(x[1], 0) - return larger_dim, larger_rank - - reducer = grouping.Reducer( - init_func=lambda x: ([0], 1), - reduce_func=reduce_fn, - finalize_func=lambda x, y: (x, y)) - - for i in range(1, 11): - dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply( - grouping.group_by_reducer(lambda x: x, reducer)) - self.assertEqual([None], dataset.output_shapes[0].as_list()) - self.assertIs(None, dataset.output_shapes[1].ndims) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - with self.cached_session() as sess: - x, y = sess.run(get_next) - self.assertAllEqual([0] * (2**i), x) - self.assertAllEqual(np.array(1, ndmin=i), y) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testTypeMismatch(self): - reducer = grouping.Reducer( - init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32), - reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64), - finalize_func=lambda x: x) - - dataset = dataset_ops.Dataset.range(10) - with self.assertRaisesRegexp( - TypeError, - "The element types for the new state must match the initial state."): - dataset.apply( - grouping.group_by_reducer(lambda _: np.int64(0), reducer)) - - # TODO(b/78665031): Remove once non-scalar keys are supported. 
- def testInvalidKeyShape(self): - reducer = grouping.Reducer( - init_func=lambda x: np.int64(0), - reduce_func=lambda x, y: x + y, - finalize_func=lambda x: x) - - dataset = dataset_ops.Dataset.range(10) - with self.assertRaisesRegexp( - ValueError, "`key_func` must return a single tf.int64 tensor."): - dataset.apply( - grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer)) - - # TODO(b/78665031): Remove once non-int64 keys are supported. - def testInvalidKeyType(self): - reducer = grouping.Reducer( - init_func=lambda x: np.int64(0), - reduce_func=lambda x, y: x + y, - finalize_func=lambda x: x) - - dataset = dataset_ops.Dataset.range(10) - with self.assertRaisesRegexp( - ValueError, "`key_func` must return a single tf.int64 tensor."): - dataset.apply( - grouping.group_by_reducer(lambda _: "wrong", reducer)) - - def testTuple(self): - def init_fn(_): - return np.array([], dtype=np.int64), np.int64(0) - - def reduce_fn(state, value): - s1, s2 = state - v1, v2 = value - return array_ops.concat([s1, [v1]], 0), s2 + v2 - - def finalize_fn(s1, s2): - return s1, s2 - - reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn) - dataset = dataset_ops.Dataset.zip( - (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply( - grouping.group_by_reducer(lambda x, y: np.int64(0), reducer)) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - x, y = sess.run(get_next) - self.assertAllEqual(x, np.asarray([x for x in range(10)])) - self.assertEqual(y, 45) - - -class GroupByWindowTest(test_base.DatasetTestBase): - - def testSimple(self): - components = np.random.randint(100, size=(200,)).astype(np.int64) - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x) - .apply( - grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), - 4)).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - 
sess.run(init_op) - counts = [] - with self.assertRaises(errors.OutOfRangeError): - while True: - result = sess.run(get_next) - self.assertTrue( - all(x % 2 == 0 - for x in result) or all(x % 2 == 1) - for x in result) - counts.append(result.shape[0]) - - self.assertEqual(len(components), sum(counts)) - num_full_batches = len([c for c in counts if c == 4]) - self.assertGreaterEqual(num_full_batches, 24) - self.assertTrue(all(c == 4 for c in counts[:num_full_batches])) - - def testImmediateOutput(self): - components = np.array( - [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( - grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), - 4)).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - # The input is infinite, so this test demonstrates that: - # 1. We produce output without having to consume the entire input, - # 2. Different buckets can produce output at different rates, and - # 3. For deterministic input, the output is deterministic. 
- for _ in range(3): - self.assertAllEqual([0, 0, 0, 0], sess.run(get_next)) - self.assertAllEqual([1, 1, 1, 1], sess.run(get_next)) - self.assertAllEqual([2, 2, 2, 2], sess.run(get_next)) - self.assertAllEqual([0, 0, 0, 0], sess.run(get_next)) - - def testSmallGroups(self): - components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64) - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components).apply( - grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), - 4)).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - self.assertAllEqual([0, 0, 0, 0], sess.run(get_next)) - self.assertAllEqual([1, 1, 1, 1], sess.run(get_next)) - # The small outputs at the end are deterministically produced in key - # order. - self.assertAllEqual([0, 0, 0], sess.run(get_next)) - self.assertAllEqual([1], sess.run(get_next)) - - def testEmpty(self): - iterator = ( - dataset_ops.Dataset.range(4).apply( - grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Window size must be greater than zero, but got 0."): - print(sess.run(get_next)) - - def testReduceFuncError(self): - components = np.random.randint(100, size=(200,)).astype(np.int64) - - def reduce_func(_, xs): - # Introduce an incorrect padded shape that cannot (currently) be - # detected at graph construction time. 
- return xs.padded_batch( - 4, - padded_shapes=(tensor_shape.TensorShape([]), - constant_op.constant([5], dtype=dtypes.int64) * -1)) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply( - grouping.group_by_window(lambda x, _: x % 2, reduce_func, - 32)).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next) - - def testConsumeWindowDatasetMoreThanOnce(self): - components = np.random.randint(50, size=(200,)).astype(np.int64) - - def reduce_func(key, window): - # Apply two different kinds of padding to the input: tight - # padding, and quantized (to a multiple of 10) padding. - return dataset_ops.Dataset.zip(( - window.padded_batch( - 4, padded_shapes=tensor_shape.TensorShape([None])), - window.padded_batch( - 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])), - )) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x)) - .apply(grouping.group_by_window( - lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64), - reduce_func, 4)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - counts = [] - with self.assertRaises(errors.OutOfRangeError): - while True: - tight_result, multiple_of_10_result = sess.run(get_next) - self.assertEqual(0, multiple_of_10_result.shape[1] % 10) - self.assertAllEqual(tight_result, - multiple_of_10_result[:, :tight_result.shape[1]]) - counts.append(tight_result.shape[0]) - self.assertEqual(len(components), sum(counts)) - - -# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py. 
-# Currently, they use a constant batch size, though should be made to use a -# different batch size per key. -class BucketTest(test_base.DatasetTestBase): - - def _dynamicPad(self, bucket, window, window_size): - # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a - # generic form of padded_batch that pads every component - # dynamically and does not rely on static shape information about - # the arguments. - return dataset_ops.Dataset.zip( - (dataset_ops.Dataset.from_tensors(bucket), - window.padded_batch( - 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape( - [None]), tensor_shape.TensorShape([3]))))) - - def testSingleBucket(self): - - def _map_fn(v): - return (v, array_ops.fill([v], v), - array_ops.fill([3], string_ops.as_string(v))) - - input_dataset = ( - dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn)) - - bucketed_dataset = input_dataset.apply( - grouping.group_by_window( - lambda x, y, z: 0, - lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)) - - iterator = bucketed_dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - - which_bucket, bucketed_values = sess.run(get_next) - - self.assertEqual(0, which_bucket) - - expected_scalar_int = np.arange(32, dtype=np.int64) - expected_unk_int64 = np.zeros((32, 31)).astype(np.int64) - for i in range(32): - expected_unk_int64[i, :i] = i - expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T - - self.assertAllEqual(expected_scalar_int, bucketed_values[0]) - self.assertAllEqual(expected_unk_int64, bucketed_values[1]) - self.assertAllEqual(expected_vec3_str, bucketed_values[2]) - - def testEvenOddBuckets(self): - - def _map_fn(v): - return (v, array_ops.fill([v], v), - array_ops.fill([3], string_ops.as_string(v))) - - input_dataset = ( - dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn)) - - bucketed_dataset = 
input_dataset.apply( - grouping.group_by_window( - lambda x, y, z: math_ops.cast(x % 2, dtypes.int64), - lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)) - - iterator = bucketed_dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - - # Get two minibatches (one containing even values, one containing odds) - which_bucket_even, bucketed_values_even = sess.run(get_next) - which_bucket_odd, bucketed_values_odd = sess.run(get_next) - - # Count number of bucket_tensors. - self.assertEqual(3, len(bucketed_values_even)) - self.assertEqual(3, len(bucketed_values_odd)) - - # Ensure bucket 0 was used for all minibatch entries. - self.assertAllEqual(0, which_bucket_even) - self.assertAllEqual(1, which_bucket_odd) - - # Test the first bucket outputted, the events starting at 0 - expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64) - expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64) - for i in range(0, 32): - expected_unk_int64[i, :2 * i] = 2 * i - expected_vec3_str = np.vstack( - 3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T - - self.assertAllEqual(expected_scalar_int, bucketed_values_even[0]) - self.assertAllEqual(expected_unk_int64, bucketed_values_even[1]) - self.assertAllEqual(expected_vec3_str, bucketed_values_even[2]) - - # Test the second bucket outputted, the odds starting at 1 - expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64) - expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64) - for i in range(0, 32): - expected_unk_int64[i, :2 * i + 1] = 2 * i + 1 - expected_vec3_str = np.vstack( - 3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T - - self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0]) - self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1]) - self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2]) - - def testEvenOddBucketsFilterOutAllOdd(self): - - def _map_fn(v): - 
return { - "x": v, - "y": array_ops.fill([v], v), - "z": array_ops.fill([3], string_ops.as_string(v)) - } - - def _dynamic_pad_fn(bucket, window, _): - return dataset_ops.Dataset.zip( - (dataset_ops.Dataset.from_tensors(bucket), - window.padded_batch( - 32, { - "x": tensor_shape.TensorShape([]), - "y": tensor_shape.TensorShape([None]), - "z": tensor_shape.TensorShape([3]) - }))) - - input_dataset = ( - dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn) - .filter(lambda d: math_ops.equal(d["x"] % 2, 0))) - - bucketed_dataset = input_dataset.apply( - grouping.group_by_window( - lambda d: math_ops.cast(d["x"] % 2, dtypes.int64), - lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32)) - - iterator = bucketed_dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - - # Get two minibatches ([0, 2, ...] and [64, 66, ...]) - which_bucket0, bucketed_values_even0 = sess.run(get_next) - which_bucket1, bucketed_values_even1 = sess.run(get_next) - - # Ensure that bucket 1 was completely filtered out - self.assertAllEqual(0, which_bucket0) - self.assertAllEqual(0, which_bucket1) - self.assertAllEqual( - np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"]) - self.assertAllEqual( - np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"]) - - def testDynamicWindowSize(self): - components = np.arange(100).astype(np.int64) - - # Key fn: even/odd - # Reduce fn: batches of 5 - # Window size fn: even=5, odd=10 - - def window_size_func(key): - window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64) - return window_sizes[key] - - dataset = dataset_ops.Dataset.from_tensor_slices(components).apply( - grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20), - None, window_size_func)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with 
self.cached_session() as sess: - sess.run(init_op) - with self.assertRaises(errors.OutOfRangeError): - batches = 0 - while True: - result = sess.run(get_next) - is_even = all(x % 2 == 0 for x in result) - is_odd = all(x % 2 == 1 for x in result) - self.assertTrue(is_even or is_odd) - expected_batch_size = 5 if is_even else 10 - self.assertEqual(expected_batch_size, result.shape[0]) - batches += 1 - - self.assertEqual(batches, 15) - - -def _element_length_fn(x, y=None): - del y - return array_ops.shape(x)[0] - - -def _to_sparse_tensor(record): - return sparse_tensor.SparseTensor(**record) - - -def _format_record(array, sparse): - if sparse: - return { - "values": array, - "indices": [[i] for i in range(len(array))], - "dense_shape": (len(array),) - } - return array - - -def _get_record_type(sparse): - if sparse: - return { - "values": dtypes.int64, - "indices": dtypes.int64, - "dense_shape": dtypes.int64 - } - return dtypes.int32 - - -def _get_record_shape(sparse): - if sparse: - return { - "values": tensor_shape.TensorShape([None,]), - "indices": tensor_shape.TensorShape([None, 1]), - "dense_shape": tensor_shape.TensorShape([1,]) - } - return tensor_shape.TensorShape([None]) - - -class BucketBySequenceLength(test_base.DatasetTestBase): - - def testBucket(self): - - boundaries = [10, 20, 30] - batch_sizes = [10, 8, 4, 2] - lengths = [8, 13, 25, 35] - - def build_dataset(sparse): - def _generator(): - # Produce 1 batch for each bucket - elements = [] - for batch_size, length in zip(batch_sizes, lengths): - record_len = length - 1 - for _ in range(batch_size): - elements.append([1] * record_len) - record_len = length - random.shuffle(elements) - for el in elements: - yield (_format_record(el, sparse),) - dataset = dataset_ops.Dataset.from_generator( - _generator, - (_get_record_type(sparse),), - (_get_record_shape(sparse),)) - if sparse: - dataset = dataset.map(lambda x: (_to_sparse_tensor(x),)) - return dataset - - def _test_bucket_by_padding(no_padding): - dataset = 
build_dataset(sparse=no_padding) - dataset = dataset.apply( - grouping.bucket_by_sequence_length( - _element_length_fn, - boundaries, - batch_sizes, - no_padding=no_padding)) - batch, = dataset.make_one_shot_iterator().get_next() - - with self.cached_session() as sess: - batches = [] - for _ in range(4): - batches.append(sess.run(batch)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(batch) - batch_sizes_val = [] - lengths_val = [] - for batch in batches: - shape = batch.dense_shape if no_padding else batch.shape - batch_size = shape[0] - length = shape[1] - batch_sizes_val.append(batch_size) - lengths_val.append(length) - sum_check = batch.values.sum() if no_padding else batch.sum() - self.assertEqual(sum_check, batch_size * length - 1) - self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) - self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) - self.assertEqual(sorted(lengths), sorted(lengths_val)) - - for no_padding in (True, False): - _test_bucket_by_padding(no_padding) - - def testPadToBoundary(self): - - boundaries = [10, 20, 30] - batch_sizes = [10, 8, 4, 2] - lengths = [8, 13, 25] - - def element_gen(): - # Produce 1 batch for each bucket - elements = [] - for batch_size, length in zip(batch_sizes[:-1], lengths): - for _ in range(batch_size): - elements.append([1] * length) - random.shuffle(elements) - for el in elements: - yield (el,) - for _ in range(batch_sizes[-1]): - el = [1] * (boundaries[-1] + 5) - yield (el,) - - element_len = lambda el: array_ops.shape(el)[0] - dataset = dataset_ops.Dataset.from_generator( - element_gen, (dtypes.int64,), ([None],)).apply( - grouping.bucket_by_sequence_length( - element_len, boundaries, batch_sizes, - pad_to_bucket_boundary=True)) - batch, = dataset.make_one_shot_iterator().get_next() - - with self.cached_session() as sess: - batches = [] - for _ in range(3): - batches.append(sess.run(batch)) - with self.assertRaisesOpError("bucket_boundaries"): - sess.run(batch) - batch_sizes_val = [] - 
lengths_val = [] - for batch in batches: - batch_size = batch.shape[0] - length = batch.shape[1] - batch_sizes_val.append(batch_size) - lengths_val.append(length) - batch_sizes = batch_sizes[:-1] - self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) - self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) - self.assertEqual([boundary - 1 for boundary in sorted(boundaries)], - sorted(lengths_val)) - - def testPadToBoundaryNoExtraneousPadding(self): - - boundaries = [3, 7, 11] - batch_sizes = [2, 2, 2, 2] - lengths = range(1, 11) - - def element_gen(): - for length in lengths: - yield ([1] * length,) - - element_len = lambda element: array_ops.shape(element)[0] - dataset = dataset_ops.Dataset.from_generator( - element_gen, (dtypes.int64,), ([None],)).apply( - grouping.bucket_by_sequence_length( - element_len, boundaries, batch_sizes, - pad_to_bucket_boundary=True)) - batch, = dataset.make_one_shot_iterator().get_next() - - with self.cached_session() as sess: - batches = [] - for _ in range(5): - batches.append(sess.run(batch)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(batch) - - self.assertAllEqual(batches[0], [[1, 0], - [1, 1]]) - self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 0, 0]]) - self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1]]) - self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) - self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) - - def testTupleElements(self): - - def build_dataset(sparse): - def _generator(): - text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] - label = [1, 2, 1, 2] - for x, y in zip(text, label): - yield (_format_record(x, sparse), y) - dataset = dataset_ops.Dataset.from_generator( - generator=_generator, - output_types=(_get_record_type(sparse), dtypes.int32), - output_shapes=(_get_record_shape(sparse), - tensor_shape.TensorShape([]))) - if sparse: - 
dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y)) - return dataset - - def _test_tuple_elements_by_padding(no_padding): - dataset = build_dataset(sparse=no_padding) - dataset = dataset.apply(grouping.bucket_by_sequence_length( - element_length_func=_element_length_fn, - bucket_batch_sizes=[2, 2, 2], - bucket_boundaries=[0, 8], - no_padding=no_padding)) - shapes = dataset.output_shapes - self.assertEqual([None, None], shapes[0].as_list()) - self.assertEqual([None], shapes[1].as_list()) - - for no_padding in (True, False): - _test_tuple_elements_by_padding(no_padding) - - def testBucketSparse(self): - """Tests bucketing of sparse tensors (case where `no_padding` == True). - - Test runs on following dataset: - [ - [0], - [0, 1], - [0, 1, 2] - ... - [0, ..., max_len - 1] - ] - Sequences are bucketed by length and batched with - `batch_size` < `bucket_size`. - """ - - min_len = 0 - max_len = 100 - batch_size = 7 - bucket_size = 10 - - def _build_dataset(): - input_data = [range(i+1) for i in range(min_len, max_len)] - def generator_fn(): - for record in input_data: - yield _format_record(record, sparse=True) - dataset = dataset_ops.Dataset.from_generator( - generator=generator_fn, - output_types=_get_record_type(sparse=True)) - dataset = dataset.map(_to_sparse_tensor) - return dataset - - def _compute_expected_batches(): - """Computes expected batch outputs and stores in a set.""" - all_expected_sparse_tensors = set() - for bucket_start_len in range(min_len, max_len, bucket_size): - for batch_offset in range(0, bucket_size, batch_size): - batch_start_len = bucket_start_len + batch_offset - batch_end_len = min(batch_start_len + batch_size, - bucket_start_len + bucket_size) - expected_indices = [] - expected_values = [] - for length in range(batch_start_len, batch_end_len): - for val in range(length + 1): - expected_indices.append((length - batch_start_len, val)) - expected_values.append(val) - expected_sprs_tensor = (tuple(expected_indices), - 
tuple(expected_values)) - all_expected_sparse_tensors.add(expected_sprs_tensor) - return all_expected_sparse_tensors - - def _compute_batches(dataset): - """Computes actual batch outputs of dataset and stores in a set.""" - batch = dataset.make_one_shot_iterator().get_next() - all_sparse_tensors = set() - with self.cached_session() as sess: - with self.assertRaises(errors.OutOfRangeError): - while True: - output = sess.run(batch) - sprs_tensor = (tuple([tuple(idx) for idx in output.indices]), - tuple(output.values)) - all_sparse_tensors.add(sprs_tensor) - return all_sparse_tensors - - dataset = _build_dataset() - boundaries = range(min_len + bucket_size + 1, max_len, bucket_size) - dataset = dataset.apply(grouping.bucket_by_sequence_length( - _element_length_fn, - boundaries, - [batch_size] * (len(boundaries) + 1), - no_padding=True)) - batches = _compute_batches(dataset) - expected_batches = _compute_expected_batches() - self.assertEqual(batches, expected_batches) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e7281d531870c75c638b5c48fa3fc6dc606a3623 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py @@ -0,0 +1,62 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.data.python.ops import get_single_element +from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): + + @parameterized.named_parameters( + ("SumZero", 0), + ("SumOne", 1), + ("SumFive", 5), + ("SumTen", 10), + ) + def testReduceDataset(self, stop): + def init_fn(_): + return np.int64(0) + + def reduce_fn(state, value): + return state + value + + def finalize_fn(state): + return state + + sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn) + + stop_t = array_ops.placeholder(dtypes.int64, shape=[]) + dataset = dataset_ops.Dataset.range(stop_t) + element = get_single_element.reduce_dataset(dataset, sum_reducer) + + with self.cached_session() as sess: + value = sess.run(element, feed_dict={stop_t: stop}) + self.assertEqual(stop * (stop - 1) / 2, value) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py deleted file mode 100644 index 79134c7bc6833b01de1511a4036aa53aae62fe70..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py +++ /dev/null @@ -1,527 +0,0 @@ -# Copyright 2017 The TensorFlow 
Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl.testing import parameterized -import numpy as np - -from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import grouping -from tensorflow.python.data.kernel_tests import test_base -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import sparse_ops -from tensorflow.python.platform import test - - -class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): - - def _structuredDataset(self, structure, shape, dtype): - if structure is None: - return dataset_ops.Dataset.from_tensors( - array_ops.zeros(shape, dtype=dtype)) - else: - return dataset_ops.Dataset.zip( - tuple([ - self._structuredDataset(substructure, shape, dtype) - for substructure in structure - ])) - - def _structuredElement(self, structure, shape, dtype): - if structure is None: - return array_ops.zeros(shape, dtype=dtype) - else: - return tuple([ - self._structuredElement(substructure, shape, 
dtype) - for substructure in structure - ]) - - def _assertEqual(self, xs, ys): - self.assertEqual(type(xs), type(ys)) - if isinstance(xs, tuple) and isinstance(ys, tuple): - self.assertEqual(len(xs), len(ys)) - for x, y in zip(xs, ys): - self._assertEqual(x, y) - elif isinstance(xs, np.ndarray) and isinstance(ys, np.ndarray): - self.assertAllEqual(xs, ys) - else: - self.assertEqual(xs, ys) - - @parameterized.named_parameters( - ("1", None, np.int32([]), dtypes.bool), - ("2", None, np.int32([]), dtypes.int32), - ("3", None, np.int32([]), dtypes.float32), - ("4", None, np.int32([]), dtypes.string), - ("5", None, np.int32([2]), dtypes.int32), - ("6", None, np.int32([2, 2]), dtypes.int32), - ("7", (None, None, None), np.int32([]), dtypes.int32), - ("8", (None, (None, None)), np.int32([]), dtypes.int32), - ) - def testWindowDatasetFlatMap(self, structure, shape, dtype): - """Tests windowing by chaining it with flat map. - - Args: - structure: the input structure - shape: the input shape - dtype: the input data type - """ - - def fn(*args): - if len(args) == 1 and not isinstance(args[0], tuple): - return args[0] - return dataset_ops.Dataset.zip( - tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args])) - - dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply( - grouping.window_dataset(5)).flat_map(fn) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - expected = sess.run(self._structuredElement(structure, shape, dtype)) - for _ in range(5): - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - @parameterized.named_parameters( - ("1", None, np.int32([]), dtypes.bool), - ("2", None, np.int32([]), dtypes.int32), - ("3", None, np.int32([]), dtypes.float32), - ("4", None, np.int32([]), dtypes.string), - ("5", None, np.int32([2]), dtypes.int32), - ("6", None, np.int32([2, 2]), dtypes.int32), - ("7", (None, None, None), np.int32([]), dtypes.int32), - ("8", (None, (None, None)), 
np.int32([]), dtypes.int32), - ) - def testWindowDatasetBatchDense(self, structure, shape, dtype): - """Tests batching of dense tensor windows. - - Args: - structure: the input structure - shape: the input shape - dtype: the input data type - """ - - def fn(*args): - if len(args) == 1 and not isinstance(args[0], tuple): - return batching.batch_window(args[0]) - - return tuple([ - fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg) - for arg in args - ]) - - dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply( - grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - expected = sess.run( - self._structuredElement(structure, np.concatenate( - ([5], shape), axis=0), dtype)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - @parameterized.named_parameters( - ("1", np.int32([])), - ("2", np.int32([1])), - ("3", np.int32([1, 2, 3])), - ) - def testWindowDatasetBatchDenseDynamicShape(self, shape): - """Tests batching of dynamically shaped dense tensor windows. 
- - Args: - shape: the input shape - """ - - shape_t = array_ops.placeholder(dtypes.int32) - dataset = dataset_ops.Dataset.from_tensors( - array_ops.zeros(shape_t)).repeat(5).apply( - grouping.window_dataset(5)).apply( - grouping._map_x_dataset(batching.batch_window)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - with self.cached_session() as sess: - sess.run(init_op, {shape_t: shape}) - expected = sess.run( - self._structuredElement(None, np.concatenate(([5], shape), axis=0), - dtypes.int32)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - def _make_dense_to_sparse_fn(self, is_scalar): - - def dense_to_sparse_scalar(tensor): - indices = [[]] - values = array_ops.expand_dims(tensor, 0) - shape = [] - return sparse_tensor.SparseTensorValue(indices, values, shape) - - def dense_to_sparse_non_scalar(tensor): - indices = array_ops.where(array_ops.ones_like(tensor, dtype=dtypes.bool)) - values = array_ops.gather_nd(tensor, indices) - shape = array_ops.shape(tensor, out_type=dtypes.int64) - return sparse_tensor.SparseTensorValue(indices, values, shape) - - if is_scalar: - return dense_to_sparse_scalar - return dense_to_sparse_non_scalar - - def _structuredSparseDataset(self, structure, shape, dtype): - dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test - if structure is None: - return dataset_ops.Dataset.from_tensors( - dense_to_sparse(array_ops.zeros(shape, dtype=dtype))) - else: - return dataset_ops.Dataset.zip( - tuple([ - self._structuredSparseDataset(substructure, shape, dtype) - for substructure in structure - ])) - - def _structuredSparseElement(self, structure, shape, dtype): - dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test - if structure is None: - return dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) - else: - return tuple([ - 
self._structuredSparseElement(substructure, shape, dtype) - for substructure in structure - ]) - - @parameterized.named_parameters( - ("1", None, np.int32([]), dtypes.bool), - ("2", None, np.int32([]), dtypes.int32), - ("3", None, np.int32([]), dtypes.float32), - ("4", None, np.int32([]), dtypes.string), - ("5", None, np.int32([2]), dtypes.int32), - ("6", None, np.int32([2, 2]), dtypes.int32), - ("7", (None, None, None), np.int32([]), dtypes.int32), - ("8", (None, (None, None)), np.int32([]), dtypes.int32), - ) - def testWindowDatasetBatchSparse(self, structure, shape, dtype): - """Tests batching of sparse tensor windows. - - Args: - structure: the input structure - shape: the input shape - dtype: the input data type - """ - - def fn(*args): - if len(args) == 1 and not isinstance(args[0], tuple): - return batching.batch_window(args[0]) - - return tuple([ - fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg) - for arg in args - ]) - - dataset = self._structuredSparseDataset( - structure, shape, dtype).repeat(5).apply( - grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - expected = sess.run( - self._structuredSparseElement(structure, - np.concatenate(([5], shape), axis=0), - dtype)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - @parameterized.named_parameters( - ("1", np.int32([])), - ("2", np.int32([1])), - ("3", np.int32([1, 2, 3])), - ) - def testWindowDatasetBatchSparseDynamicShape(self, shape): - """Tests batching of dynamically shaped sparse tensor windows. 
- - Args: - shape: the input shape - """ - - shape_t = array_ops.placeholder(dtypes.int32) - dataset = dataset_ops.Dataset.from_tensors(array_ops.zeros(shape_t)).map( - self._make_dense_to_sparse_fn(len(shape) == 0)).repeat(5).apply( # pylint: disable=g-explicit-length-test - grouping.window_dataset(5)).apply( - grouping._map_x_dataset(batching.batch_window)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - with self.cached_session() as sess: - sess.run(init_op, {shape_t: shape}) - expected = sess.run( - self._structuredSparseElement(None, - np.concatenate(([5], shape), axis=0), - dtypes.int32)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - def _structuredRaggedDataset(self, structure, shapes, dtype): - - if structure is None: - return dataset_ops.Dataset.from_tensor_slices(shapes).map( - lambda shape: array_ops.zeros(shape, dtype=dtype)) - else: - return dataset_ops.Dataset.zip( - tuple([ - self._structuredRaggedDataset(substructure, shapes, dtype) - for substructure in structure - ])) - - @parameterized.named_parameters( - ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]), - ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]), - ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]), - ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), - ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]), - ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - ("8", (None, - (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])), - ) - def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype, - padded_shape): - """Tests padded batching of dense tensor windows. 
- - Args: - structure: the input structure - shapes: the input shapes - dtype: the input data type - padded_shape: the shape to pad the output to - """ - - def fn(*args): - if len(args) == 1 and not isinstance(args[0], tuple): - return batching.padded_batch_window(args[0], padded_shape) - - return tuple([ - fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window( - arg, padded_shape) for arg in args - ]) - - dataset = self._structuredRaggedDataset(structure, shapes, dtype).apply( - grouping.window_dataset(len(shapes))).apply( - grouping._map_x_dataset(fn)) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) - expected = sess.run( - self._structuredElement( - structure, - np.concatenate((np.int32([len(shapes)]), expected_shape)), dtype)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - @parameterized.named_parameters( - ("1", np.int32([[1], [2], [3]]), [-1]), - ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]), - ("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), - ) - def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape): - """Tests padded batching of dynamically shaped dense tensor windows. 
- - Args: - shapes: the input shapes - padded_shape: the shape to pad the output to - """ - - shapes_t = array_ops.placeholder(dtypes.int32) - dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map( - lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply( - grouping.window_dataset(len(shapes))).apply( - grouping._map_x_dataset( - lambda x: batching.padded_batch_window(x, padded_shape))) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - with self.cached_session() as sess: - sess.run(init_op, {shapes_t: shapes}) - expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) - expected = sess.run( - self._structuredElement( - None, np.concatenate((np.int32([len(shapes)]), expected_shape)), - dtypes.int32)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - @parameterized.named_parameters( - ("1", np.int32([[1]]), np.int32([0])), - ("2", np.int32([[10], [20]]), np.int32([15])), - ) - def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape): - """Tests invalid padded batching of dense tensor windows. 
- - Args: - shapes: the input shapes - padded_shape: the shape to pad the output to - """ - - dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map( - lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply( - grouping.window_dataset(len(shapes))).apply( - grouping._map_x_dataset( - lambda x: batching.padded_batch_window(x, padded_shape))) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next) - - def _structuredRaggedSparseDataset(self, structure, shapes, dtype): - - def map_fn(shape): - dense_to_sparse = self._make_dense_to_sparse_fn(False) - return dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) - - if structure is None: - return dataset_ops.Dataset.from_tensor_slices(shapes).map(map_fn) - else: - return dataset_ops.Dataset.zip( - tuple([ - self._structuredRaggedSparseDataset(substructure, shapes, dtype) - for substructure in structure - ])) - - def _structuredRaggedSparseElement(self, structure, shapes, dtype, - padded_shape): - if structure is None: - dense_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) - values = [] - for shape in shapes: - dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test - sparse = dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) - padded_sparse = sparse_tensor.SparseTensor(sparse.indices, - sparse.values, dense_shape) - reshaped_sparse = sparse_ops.sparse_reshape( - padded_sparse, - array_ops.concat([np.array([1], dtype=np.int64), dense_shape], 0)) - values.append(reshaped_sparse) - return sparse_ops.sparse_concat(0, values) - else: - return tuple([ - self._structuredRaggedSparseElement(substructure, shapes, dtype, - padded_shape) - for substructure in structure - ]) - - @parameterized.named_parameters( - ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]), - ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - ("3", 
None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]), - ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]), - ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), - ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]), - ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - ("8", (None, - (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])), - ) - def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype, - padded_shape): - """Tests padded batching of sparse tensor windows. - - Args: - structure: the input structure - shapes: the input shapes - dtype: the input data type - padded_shape: the shape to pad the output to - """ - - def fn(*args): - if len(args) == 1 and not isinstance(args[0], tuple): - return batching.padded_batch_window(args[0], padded_shape) - - return tuple([ - fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window( - arg, padded_shape) for arg in args - ]) - - dataset = self._structuredRaggedSparseDataset( - structure, shapes, dtype).apply(grouping.window_dataset( - len(shapes))).apply(grouping._map_x_dataset(fn)) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - expected = sess.run( - self._structuredRaggedSparseElement(structure, shapes, dtype, - padded_shape)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - @parameterized.named_parameters( - ("1", np.int64([[1], [2], [3]]), [-1]), - ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]), - ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), - ) - def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes, - padded_shape): - """Tests padded batching of dynamically shaped sparse tensor windows. 
- - Args: - shapes: the input shapes - padded_shape: the shape to pad the output to - """ - - shapes_t = array_ops.placeholder(dtypes.int32) - dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map( - lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map( - self._make_dense_to_sparse_fn(False) - ).apply(grouping.window_dataset(len(shapes))).apply( - grouping._map_x_dataset( - lambda x: batching.padded_batch_window(x, padded_shape))) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - with self.cached_session() as sess: - sess.run(init_op, {shapes_t: shapes}) - expected = sess.run( - self._structuredRaggedSparseElement(None, shapes, dtypes.int32, - padded_shape)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - @parameterized.named_parameters( - ("1", np.int64([[1]]), [0]), - ("2", np.int64([[10], [20]]), [15]), - ) - def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape): - """Tests invalid padded batching of sparse tensor windows. 
- - Args: - shapes: the input shapes - padded_shape: the shape to pad the output to - """ - - dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map( - lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map( - self._make_dense_to_sparse_fn(False) - ).apply(grouping.window_dataset(len(shapes))).apply( - grouping._map_x_dataset( - lambda x: batching.padded_batch_window(x, padded_shape))) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 5cd1ed542bf0795b6c3756dd5b40ef739816af60..34dc2379d0cb38f8f6962fa42efe21b793bc8d65 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -16,10 +16,7 @@ py_library( srcs = ["counter.py"], srcs_version = "PY2AND3", deps = [ - ":scan_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/experimental/ops:counter", ], ) @@ -28,12 +25,7 @@ py_library( srcs = ["get_single_element.py"], srcs_version = "PY2AND3", deps = [ - ":grouping", - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", - "//third_party/py/numpy", + "//tensorflow/python/data/experimental/ops:get_single_element", ], ) @@ -44,10 +36,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/experimental/ops:iterator_ops", ], ) @@ -58,15 +47,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:constant_op", - 
"//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:random_seed", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:random_ops", ], ) @@ -79,7 +60,6 @@ py_library( deps = [ ":batching", ":interleave_ops", - ":optimization", ":parsing_ops", ":shuffle_ops", "//tensorflow/python:constant_op", @@ -91,6 +71,7 @@ py_library( "//tensorflow/python:platform", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", + "//tensorflow/python/data/experimental/ops:readers", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", "//tensorflow/python/data/util:convert", @@ -106,7 +87,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/experimental/ops:shuffle_ops", ], ) @@ -125,6 +106,7 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", + "//tensorflow/python/data/experimental/ops:batching", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", @@ -138,8 +120,7 @@ py_library( srcs = ["enumerate_ops.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dtypes", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/experimental/ops:enumerate_ops", ], ) @@ -148,10 +129,7 @@ py_library( srcs = ["error_ops.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:experimental_dataset_ops_gen", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:error_ops", ], ) @@ -160,16 +138,7 @@ py_library( srcs = ["grouping.py"], srcs_version = "PY2AND3", deps = 
[ - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:function", - "//tensorflow/python:math_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:grouping", ], ) @@ -178,30 +147,7 @@ py_library( srcs = ["interleave_ops.py"], srcs_version = "PY2AND3", deps = [ - ":random_ops", - "//tensorflow/contrib/stateless", - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:experimental_dataset_ops_gen", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:readers", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", - ], -) - -py_library( - name = "optimization", - srcs = ["optimization.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:dtypes", - "//tensorflow/python:experimental_dataset_ops_gen", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:interleave_ops", ], ) @@ -210,25 +156,7 @@ py_library( srcs = ["parsing_ops.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - ], -) - -py_library( - name = "map_defun", - srcs = ["map_defun.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:framework_ops", - "//tensorflow/python:tensor_shape", + 
"//tensorflow/python/data/experimental/ops:parsing_ops", ], ) @@ -237,18 +165,7 @@ py_library( srcs = ["resampling.py"], srcs_version = "PY2AND3", deps = [ - ":batching", - ":interleave_ops", - ":scan_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:logging_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", + "//tensorflow/python/data/experimental/ops:resampling", ], ) @@ -257,12 +174,7 @@ py_library( srcs = ["scan_ops.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:framework_ops", - "//tensorflow/python:function", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:scan_ops", ], ) @@ -281,32 +193,12 @@ py_library( ], ) -py_library( - name = "stats_ops", - srcs = ["stats_ops.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", - ], -) - py_library( name = "threadpool", srcs = ["threadpool.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:experimental_dataset_ops_gen", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", - "//tensorflow/python/eager:context", + "//tensorflow/python/data/experimental/ops:threadpool", ], ) @@ -317,11 +209,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dtypes", - 
"//tensorflow/python:experimental_dataset_ops_gen", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:unique", ], ) @@ -332,20 +220,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dtypes", - "//tensorflow/python/data/ops:dataset_ops", - ], -) - -py_library( - name = "indexed_dataset_ops", - srcs = ["indexed_dataset_ops.py"], - deps = [ - "//tensorflow/python:experimental_dataset_ops_gen", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:writers", ], ) @@ -353,11 +228,7 @@ py_library( name = "prefetching_ops", srcs = ["prefetching_ops.py"], deps = [ - "//tensorflow/python:experimental_dataset_ops_gen", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:prefetching_ops", ], ) @@ -370,17 +241,14 @@ py_library( ":error_ops", ":get_single_element", ":grouping", - ":indexed_dataset_ops", ":interleave_ops", - ":map_defun", - ":optimization", ":prefetching_ops", + ":random_ops", ":readers", ":resampling", ":scan_ops", ":shuffle_ops", ":sliding", - ":stats_ops", ":threadpool", ":unique", ":writers", diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 7a0f221284385ed1de2b7389c9419f4b3da3c5b8..8c60459ca81cd7a7e08d90339011c54275ea9c0b 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -17,134 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from tensorflow.contrib.data.python.ops import 
get_single_element -from tensorflow.contrib.data.python.ops import grouping from tensorflow.contrib.framework import with_shape -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import convert +from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gen_array_ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import sparse_ops from tensorflow.python.util import deprecation -def batch_window(dataset): - """Batches a window of tensors. - - Args: - dataset: the input dataset. - - Returns: - A `Tensor` representing the batch of the entire input dataset. 
- """ - if isinstance(dataset.output_classes, tuple): - raise TypeError("Input dataset expected to have a single component") - if dataset.output_classes is ops.Tensor: - return _batch_dense_window(dataset) - elif dataset.output_classes is sparse_tensor.SparseTensor: - return _batch_sparse_window(dataset) - else: - raise TypeError("Unsupported dataset type: %s" % dataset.output_classes) - - -def _batch_dense_window(dataset): - """Batches a window of dense tensors.""" - - def key_fn(_): - return np.int64(0) - - def shape_init_fn(_): - return array_ops.shape(first_element) - - def shape_reduce_fn(state, value): - check_ops.assert_equal(state, array_ops.shape(value)) - return state - - def finalize_fn(state): - return state - - if dataset.output_shapes.is_fully_defined(): - shape = dataset.output_shapes - else: - first_element = get_single_element.get_single_element(dataset.take(1)) - shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn, - finalize_fn) - shape = get_single_element.get_single_element( - dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer))) - - def batch_init_fn(_): - batch_shape = array_ops.concat([[0], shape], 0) - return gen_array_ops.empty(batch_shape, dtype=dataset.output_types) - - def batch_reduce_fn(state, value): - return array_ops.concat([state, [value]], 0) - - batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) - return get_single_element.get_single_element( - dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer))) - - -def _batch_sparse_window(dataset): - """Batches a window of sparse tensors.""" - - def key_fn(_): - return np.int64(0) - - def shape_init_fn(_): - return first_element.dense_shape - - def shape_reduce_fn(state, value): - check_ops.assert_equal(state, value.dense_shape) - return state - - def finalize_fn(state): - return state - - if dataset.output_shapes.is_fully_defined(): - shape = dataset.output_shapes - else: - first_element = 
get_single_element.get_single_element(dataset.take(1)) - shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn, - finalize_fn) - shape = get_single_element.get_single_element( - dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer))) - - def batch_init_fn(_): - indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0) - return sparse_tensor.SparseTensor( - indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64), - values=constant_op.constant([], shape=[0], dtype=dataset.output_types), - dense_shape=array_ops.concat( - [np.array([0], dtype=np.int64), - math_ops.cast(shape, dtypes.int64)], 0)) - - def batch_reduce_fn(state, value): - return sparse_ops.sparse_concat(0, [state, value]) - - def reshape_fn(value): - return sparse_ops.sparse_reshape( - value, - array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0)) - - batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) - return get_single_element.get_single_element( - dataset.map(reshape_fn).apply( - grouping.group_by_reducer(key_fn, batch_reducer))) - - +@deprecation.deprecated( + None, "Use `tf.data.experimental.dense_to_sparse_batch(...)`.") def dense_to_sparse_batch(batch_size, row_shape): """A transformation that batches ragged elements into `tf.SparseTensor`s. @@ -187,201 +67,10 @@ def dense_to_sparse_batch(batch_size, row_shape): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - - def _apply_fn(dataset): - return _DenseToSparseBatchDataset(dataset, batch_size, row_shape) - - return _apply_fn - - -def padded_batch_window(dataset, padded_shape, padding_value=None): - """Batches a window of tensors with padding. - - Args: - dataset: the input dataset. - padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like - object representing the shape to which the input elements should be padded - prior to batching. Any unknown dimensions (e.g. 
`tf.Dimension(None)` in a - `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the - maximum size of that dimension in each batch. - padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the - padding value to use. Defaults are `0` for numeric types and the empty - string for string types. If `dataset` contains `tf.SparseTensor`, this - value is ignored. - - Returns: - A `Tensor` representing the batch of the entire input dataset. - - Raises: - ValueError: if invalid arguments are provided. - """ - if not issubclass(dataset.output_classes, - (ops.Tensor, sparse_tensor.SparseTensor)): - raise TypeError("Input dataset expected to have a single tensor component") - if issubclass(dataset.output_classes, (ops.Tensor)): - return _padded_batch_dense_window(dataset, padded_shape, padding_value) - elif issubclass(dataset.output_classes, (sparse_tensor.SparseTensor)): - if padding_value is not None: - raise ValueError("Padding value not allowed for sparse tensors") - return _padded_batch_sparse_window(dataset, padded_shape) - else: - raise TypeError("Unsupported dataset type: %s" % dataset.output_classes) - - -def _padded_batch_dense_window(dataset, padded_shape, padding_value=None): - """Batches a window of dense tensors with padding.""" - - padded_shape = math_ops.cast( - convert.partial_shape_to_tensor(padded_shape), dtypes.int32) - - def key_fn(_): - return np.int64(0) - - def max_init_fn(_): - return padded_shape - - def max_reduce_fn(state, value): - """Computes the maximum shape to pad to.""" - condition = math_ops.reduce_all( - math_ops.logical_or( - math_ops.less_equal(array_ops.shape(value), padded_shape), - math_ops.equal(padded_shape, -1))) - assert_op = control_flow_ops.Assert(condition, [ - "Actual shape greater than padded shape: ", - array_ops.shape(value), padded_shape - ]) - with ops.control_dependencies([assert_op]): - return math_ops.maximum(state, array_ops.shape(value)) - - def finalize_fn(state): - return state - - # 
Compute the padded shape. - max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn) - padded_shape = get_single_element.get_single_element( - dataset.apply(grouping.group_by_reducer(key_fn, max_reducer))) - - if padding_value is None: - if dataset.output_types == dtypes.string: - padding_value = "" - elif dataset.output_types == dtypes.bool: - padding_value = False - elif dataset.output_types == dtypes.variant: - raise TypeError("Unable to create padding for field of type 'variant'") - else: - padding_value = 0 - - def batch_init_fn(_): - batch_shape = array_ops.concat( - [np.array([0], dtype=np.int32), padded_shape], 0) - return gen_array_ops.empty(batch_shape, dtype=dataset.output_types) - - def batch_reduce_fn(state, value): - return array_ops.concat([state, [value]], 0) - - def pad_fn(value): - shape = array_ops.shape(value) - left = array_ops.zeros_like(shape) - right = padded_shape - shape - return array_ops.pad( - value, array_ops.stack([left, right], 1), constant_values=padding_value) - - batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) - return get_single_element.get_single_element( - dataset.map(pad_fn).apply( - grouping.group_by_reducer(key_fn, batch_reducer))) - - -def _padded_batch_sparse_window(dataset, padded_shape): - """Batches a window of sparse tensors with padding.""" - - def key_fn(_): - return np.int64(0) - - def max_init_fn(_): - return convert.partial_shape_to_tensor(padded_shape) - - def max_reduce_fn(state, value): - """Computes the maximum shape to pad to.""" - condition = math_ops.reduce_all( - math_ops.logical_or( - math_ops.less_equal(value.dense_shape, padded_shape), - math_ops.equal(padded_shape, -1))) - assert_op = control_flow_ops.Assert(condition, [ - "Actual shape greater than padded shape: ", value.dense_shape, - padded_shape - ]) - with ops.control_dependencies([assert_op]): - return math_ops.maximum(state, value.dense_shape) - - def finalize_fn(state): - return state - - # Compute 
the padded shape. - max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn) - padded_shape = get_single_element.get_single_element( - dataset.apply(grouping.group_by_reducer(key_fn, max_reducer))) - - def batch_init_fn(_): - indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]], - 0) - return sparse_tensor.SparseTensor( - indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64), - values=constant_op.constant([], shape=[0], dtype=dataset.output_types), - dense_shape=array_ops.concat( - [np.array([0], dtype=np.int64), padded_shape], 0)) - - def batch_reduce_fn(state, value): - padded_value = sparse_tensor.SparseTensor( - indices=value.indices, values=value.values, dense_shape=padded_shape) - reshaped_value = sparse_ops.sparse_reshape( - padded_value, - array_ops.concat( - [np.array([1], dtype=np.int64), padded_value.dense_shape], 0)) - return sparse_ops.sparse_concat(0, [state, reshaped_value]) - - reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) - return get_single_element.get_single_element( - dataset.apply(grouping.group_by_reducer(key_fn, reducer))) - - -class _UnbatchDataset(dataset_ops.UnaryDataset): - """A dataset that splits the elements of its input into multiple elements.""" - - def __init__(self, input_dataset): - """See `unbatch()` for more details.""" - super(_UnbatchDataset, self).__init__(input_dataset) - flat_shapes = nest.flatten(input_dataset.output_shapes) - if any(s.ndims == 0 for s in flat_shapes): - raise ValueError("Cannot unbatch an input with scalar components.") - known_batch_dim = tensor_shape.Dimension(None) - for s in flat_shapes: - try: - known_batch_dim = known_batch_dim.merge_with(s[0]) - except ValueError: - raise ValueError("Cannot unbatch an input whose components have " - "different batch sizes.") - self._input_dataset = input_dataset - - def _as_variant_tensor(self): - return gen_dataset_ops.unbatch_dataset( - self._input_dataset._as_variant_tensor(), # pylint: 
disable=protected-access - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - return nest.map_structure(lambda s: s[1:], - self._input_dataset.output_shapes) - - @property - def output_types(self): - return self._input_dataset.output_types + return batching.dense_to_sparse_batch(batch_size, row_shape) +@deprecation.deprecated(None, "Use `tf.data.experimental.unbatch()`.") def unbatch(): """Splits elements of a dataset into multiple elements on the batch dimension. @@ -403,39 +92,7 @@ def unbatch(): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - if not sparse.any_sparse(dataset.output_classes): - return _UnbatchDataset(dataset) - - # NOTE(mrry): We must ensure that any SparseTensors in `dataset` - # are normalized to the rank-1 dense representation, so that the - # sparse-oblivious unbatching logic will slice them - # appropriately. This leads to a somewhat inefficient re-encoding step - # for all SparseTensor components. - # TODO(mrry): Consider optimizing this in future - # if it turns out to be a bottleneck. - def normalize(arg, *rest): - if rest: - return sparse.serialize_many_sparse_tensors((arg,) + rest) - else: - return sparse.serialize_many_sparse_tensors(arg) - - normalized_dataset = dataset.map(normalize) - - # NOTE(mrry): Our `map()` has lost information about the sparseness - # of any SparseTensor components, so re-apply the structure of the - # original dataset. 
- restructured_dataset = _RestructuredDataset( - normalized_dataset, - dataset.output_types, - dataset.output_shapes, - dataset.output_classes, - allow_unsafe_cast=True) - return _UnbatchDataset(restructured_dataset) - - return _apply_fn + return batching.unbatch() @deprecation.deprecated( @@ -514,135 +171,8 @@ def padded_batch_and_drop_remainder(batch_size, return _apply_fn -class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset): - """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s.""" - - def __init__(self, input_dataset, batch_size, row_shape): - """See `Dataset.dense_to_sparse_batch()` for more details.""" - super(_DenseToSparseBatchDataset, self).__init__(input_dataset) - if not isinstance(input_dataset.output_types, dtypes.DType): - raise TypeError("DenseToSparseDataset requires an input whose elements " - "have a single component, whereas the input has %r." % - input_dataset.output_types) - self._input_dataset = input_dataset - self._batch_size = batch_size - self._row_shape = row_shape - - def _as_variant_tensor(self): - return gen_dataset_ops.dense_to_sparse_batch_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._batch_size, - row_shape=convert.partial_shape_to_tensor(self._row_shape), - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return sparse_tensor.SparseTensor - - @property - def output_shapes(self): - return tensor_shape.vector(None).concatenate(self._row_shape) - - @property - def output_types(self): - return self._input_dataset.output_types - - -class _RestructuredDataset(dataset_ops.UnaryDataset): - """An internal helper for changing the structure and shape of a dataset.""" - - def __init__(self, - dataset, - output_types, - output_shapes=None, - output_classes=None, - allow_unsafe_cast=False): - """Creates a new dataset with the given output types and shapes. 
- - The given `dataset` must have a structure that is convertible: - * `dataset.output_types` must be the same as `output_types` module nesting. - * Each shape in `dataset.output_shapes` must be compatible with each shape - in `output_shapes` (if given). - - Note: This helper permits "unsafe casts" for shapes, equivalent to using - `tf.Tensor.set_shape()` where domain-specific knowledge is available. - - Args: - dataset: A `Dataset` object. - output_types: A nested structure of `tf.DType` objects. - output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects. - If omitted, the shapes will be inherited from `dataset`. - output_classes: (Optional.) A nested structure of class types. - If omitted, the class types will be inherited from `dataset`. - allow_unsafe_cast: (Optional.) If `True`, the caller may switch the - reported output types and shapes of the restructured dataset, e.g. to - switch a sparse tensor represented as `tf.variant` to its user-visible - type and shape. - - Raises: - ValueError: If either `output_types` or `output_shapes` is not compatible - with the structure of `dataset`. - """ - super(_RestructuredDataset, self).__init__(dataset) - self._input_dataset = dataset - - if not allow_unsafe_cast: - # Validate that the types are compatible. - output_types = nest.map_structure(dtypes.as_dtype, output_types) - flat_original_types = nest.flatten(dataset.output_types) - flat_new_types = nest.flatten(output_types) - if flat_original_types != flat_new_types: - raise ValueError( - "Dataset with output types %r cannot be restructured to have " - "output types %r" % (dataset.output_types, output_types)) - - self._output_types = output_types - - if output_shapes is None: - # Inherit shapes from the original `dataset`. - self._output_shapes = nest.pack_sequence_as(output_types, - nest.flatten( - dataset.output_shapes)) - else: - if not allow_unsafe_cast: - # Validate that the shapes are compatible. 
- nest.assert_same_structure(output_types, output_shapes) - flat_original_shapes = nest.flatten(dataset.output_shapes) - flat_new_shapes = nest.flatten_up_to(output_types, output_shapes) - - for original_shape, new_shape in zip(flat_original_shapes, - flat_new_shapes): - if not original_shape.is_compatible_with(new_shape): - raise ValueError( - "Dataset with output shapes %r cannot be restructured to have " - "incompatible output shapes %r" % (dataset.output_shapes, - output_shapes)) - self._output_shapes = nest.map_structure_up_to( - output_types, tensor_shape.as_shape, output_shapes) - if output_classes is None: - # Inherit class types from the original `dataset`. - self._output_classes = nest.pack_sequence_as(output_types, - nest.flatten( - dataset.output_classes)) - else: - self._output_classes = output_classes - - def _as_variant_tensor(self): - return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access - - @property - def output_classes(self): - return self._output_classes - - @property - def output_types(self): - return self._output_types - - @property - def output_shapes(self): - return self._output_shapes - - +# TODO(b/116817045): Move this to `tf.data.experimental` when the `with_shape()` +# function is available in the core. def assert_element_shape(expected_shapes): """Assert the shape of this `Dataset`. 
@@ -687,7 +217,8 @@ def assert_element_shape(expected_shapes): def _apply_fn(dataset): output_shapes = _merge_output_shapes(dataset.output_shapes, expected_shapes) - return _RestructuredDataset( + # pylint: disable=protected-access + return batching._RestructuredDataset( dataset.map(_check_shape), dataset.output_types, output_shapes=output_shapes, @@ -696,49 +227,7 @@ def assert_element_shape(expected_shapes): return _apply_fn -class _MapAndBatchDataset(dataset_ops.MapDataset): - """A `Dataset` that maps a function over a batch of elements.""" - - def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls, - drop_remainder): - """See `Dataset.map()` for details.""" - super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) - self._batch_size_t = ops.convert_to_tensor( - batch_size, dtype=dtypes.int64, name="batch_size") - self._num_parallel_calls_t = ops.convert_to_tensor( - num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") - self._drop_remainder_t = ops.convert_to_tensor( - drop_remainder, dtype=dtypes.bool, name="drop_remainder") - - self._batch_size = batch_size - self._drop_remainder = drop_remainder - - def _as_variant_tensor(self): - # pylint: disable=protected-access - input_resource = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.map_and_batch_dataset_v2( - input_resource, - self._map_func.captured_inputs, - f=self._map_func, - batch_size=self._batch_size_t, - num_parallel_calls=self._num_parallel_calls_t, - drop_remainder=self._drop_remainder_t, - **dataset_ops.flat_structure(self)) - # pylint: enable=protected-access - - @property - def output_shapes(self): - dim = self._batch_size if self._drop_remainder else None - return nest.pack_sequence_as(self._output_shapes, [ - tensor_shape.vector(dim).concatenate(s) - for s in nest.flatten(self._output_shapes) - ]) - - @property - def output_types(self): - return self._output_types - - +@deprecation.deprecated(None, "Use 
`tf.data.experimental.map_and_batch(...)`.") def map_and_batch(map_func, batch_size, num_parallel_batches=None, @@ -779,17 +268,5 @@ def map_and_batch(map_func, ValueError: If both `num_parallel_batches` and `num_parallel_calls` are specified. """ - - if num_parallel_batches is None and num_parallel_calls is None: - num_parallel_calls = batch_size - elif num_parallel_batches is not None and num_parallel_calls is None: - num_parallel_calls = batch_size * num_parallel_batches - elif num_parallel_batches is not None and num_parallel_calls is not None: - raise ValueError("The `num_parallel_batches` and `num_parallel_calls` " - "arguments are mutually exclusive.") - - def _apply_fn(dataset): - return _MapAndBatchDataset(dataset, map_func, batch_size, - num_parallel_calls, drop_remainder) - - return _apply_fn + return batching.map_and_batch(map_func, batch_size, num_parallel_batches, + drop_remainder, num_parallel_calls) diff --git a/tensorflow/contrib/data/python/ops/counter.py b/tensorflow/contrib/data/python/ops/counter.py index 6ef65f9624601286691505a795a86dd6226eead1..4ff5bf3e39dc2c9313b7d47d1ef965ebb22afc06 100644 --- a/tensorflow/contrib/data/python/ops/counter.py +++ b/tensorflow/contrib/data/python/ops/counter.py @@ -17,13 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import scan_ops - -from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.experimental.ops import counter from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, "Use `tf.data.experimental.Counter(...)`.") def Counter(start=0, step=1, dtype=dtypes.int64): """Creates a `Dataset` that counts from `start` in steps of size `step`. @@ -46,8 +45,4 @@ def Counter(start=0, step=1, dtype=dtypes.int64): Returns: A `Dataset` of scalar `dtype` elements. 
""" - with ops.name_scope("counter"): - start = ops.convert_to_tensor(start, dtype=dtype, name="start") - step = ops.convert_to_tensor(step, dtype=dtype, name="step") - return dataset_ops.Dataset.from_tensors(0).repeat(None).apply( - scan_ops.scan(start, lambda state, _: (state + step, state))) + return counter.Counter(start, step, dtype) diff --git a/tensorflow/contrib/data/python/ops/enumerate_ops.py b/tensorflow/contrib/data/python/ops/enumerate_ops.py index 490281e0d2da7a454a2f63f95753c7c436b87a76..a21da4d3eca508f2af9bac49d57fb0c4b08f3be0 100644 --- a/tensorflow/contrib/data/python/ops/enumerate_ops.py +++ b/tensorflow/contrib/data/python/ops/enumerate_ops.py @@ -17,12 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes +from tensorflow.python.data.experimental.ops import enumerate_ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.enumerate_dataset(...)`.") def enumerate_dataset(start=0): """A transformation that enumerate the elements of a dataset. @@ -49,10 +50,4 @@ def enumerate_dataset(start=0): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. 
""" - - def _apply_fn(dataset): - max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max - return dataset_ops.Dataset.zip((dataset_ops.Dataset.range(start, max_value), - dataset)) - - return _apply_fn + return enumerate_ops.enumerate_dataset(start) diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index f962e623ee7a708912eb6d60b0e7c8613a975308..0559a2e09cce43cf16e88dbe20dba2c46db4c979 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -17,10 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.ops import gen_experimental_dataset_ops +from tensorflow.python.data.experimental.ops import error_ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, "Use `tf.data.experimental.ignore_errors()`.") def ignore_errors(): """Creates a `Dataset` from another `Dataset` and silently ignores any errors. @@ -43,34 +44,4 @@ def ignore_errors(): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. 
""" - - def _apply_fn(dataset): - return _IgnoreErrorsDataset(dataset) - - return _apply_fn - - -class _IgnoreErrorsDataset(dataset_ops.UnaryDataset): - """A `Dataset` that silently ignores errors when computing its input.""" - - def __init__(self, input_dataset): - """See `Dataset.ignore_errors()` for details.""" - super(_IgnoreErrorsDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - - def _as_variant_tensor(self): - return gen_experimental_dataset_ops.experimental_ignore_errors_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types + return error_ops.ignore_errors() diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py index a6713b017afa315edec9389d0a6c1c7135e6aeb9..58ad9eea903c42981b8fd083ed1c39421c58189f 100644 --- a/tensorflow/contrib/data/python/ops/get_single_element.py +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -19,13 +19,13 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.experimental.ops import get_single_element as experimental_get_single_element from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse -from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.get_single_element(...)`.") def get_single_element(dataset): """Returns the single element in `dataset` as a nested structure of tensors. 
@@ -61,18 +61,10 @@ def get_single_element(dataset): InvalidArgumentError (at runtime): if `dataset` does not contain exactly one element. """ - if not isinstance(dataset, dataset_ops.Dataset): - raise TypeError("`dataset` must be a `tf.data.Dataset` object.") - - nested_ret = nest.pack_sequence_as( - dataset.output_types, gen_dataset_ops.dataset_to_single_element( - dataset._as_variant_tensor(), # pylint: disable=protected-access - **dataset_ops.flat_structure(dataset))) - return sparse.deserialize_sparse_tensors( - nested_ret, dataset.output_types, dataset.output_shapes, - dataset.output_classes) + return experimental_get_single_element.get_single_element(dataset) +@deprecation.deprecated(None, "Use `tf.data.Dataset.reduce(...)`.") def reduce_dataset(dataset, reducer): """Returns the result of reducing the `dataset` using `reducer`. @@ -90,11 +82,4 @@ def reduce_dataset(dataset, reducer): if not isinstance(dataset, dataset_ops.Dataset): raise TypeError("`dataset` must be a `tf.data.Dataset` object.") - # The sentinel dataset is used in case the reduced dataset is empty. 
- sentinel_dataset = dataset_ops.Dataset.from_tensors( - reducer.finalize_func(reducer.init_func(np.int64(0)))) - reduced_dataset = dataset.apply( - grouping.group_by_reducer(lambda x: np.int64(0), reducer)) - - return get_single_element( - reduced_dataset.concatenate(sentinel_dataset).take(1)) + return dataset.reduce(reducer.init_func(np.int64(0)), reducer.reduce_func) diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 7cae33beb35e0280c0e3cc0e3f74eaab7f76dac2..a99dc2f29ae4c9d47c21afd83f49bf4eb89eca18 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -17,20 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import math_ops +from tensorflow.python.data.experimental.ops import grouping +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.group_by_reducer(...)`.") def group_by_reducer(key_func, reducer): """A transformation that groups elements and performs a reduction. @@ -52,14 +45,11 @@ def group_by_reducer(key_func, reducer): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. 
""" - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - return _GroupByReducerDataset(dataset, key_func, reducer) - - return _apply_fn + return grouping.group_by_reducer(key_func, reducer) +@deprecation.deprecated(None, + "Use `tf.data.experimental.group_by_window(...)`.") def group_by_window(key_func, reduce_func, window_size=None, @@ -98,27 +88,12 @@ def group_by_window(key_func, ValueError: if neither or both of {`window_size`, `window_size_func`} are passed. """ - if (window_size is not None and window_size_func or - not (window_size is not None or window_size_func)): - raise ValueError("Must pass either window_size or window_size_func.") - - if window_size is not None: - - def constant_window_func(unused_key): - return ops.convert_to_tensor(window_size, dtype=dtypes.int64) - - window_size_func = constant_window_func - - assert window_size_func is not None - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - return _GroupByWindowDataset(dataset, key_func, reduce_func, - window_size_func) - - return _apply_fn + return grouping.group_by_window(key_func, reduce_func, window_size, + window_size_func) +@deprecation.deprecated( + None, "Use `tf.data.experimental.bucket_by_sequence_length(...)`.") def bucket_by_sequence_length(element_length_func, bucket_boundaries, bucket_batch_sizes, @@ -163,342 +138,12 @@ def bucket_by_sequence_length(element_length_func, Raises: ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`. 
""" - with ops.name_scope("bucket_by_seq_length"): - if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1): - raise ValueError( - "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1") - - batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64) - - def element_to_bucket_id(*args): - """Return int64 id of the length bucket for this element.""" - seq_length = element_length_func(*args) - - boundaries = list(bucket_boundaries) - buckets_min = [np.iinfo(np.int32).min] + boundaries - buckets_max = boundaries + [np.iinfo(np.int32).max] - conditions_c = math_ops.logical_and( - math_ops.less_equal(buckets_min, seq_length), - math_ops.less(seq_length, buckets_max)) - bucket_id = math_ops.reduce_min(array_ops.where(conditions_c)) - - return bucket_id - - def window_size_fn(bucket_id): - # The window size is set to the batch size for this bucket - window_size = batch_sizes[bucket_id] - return window_size - - def make_padded_shapes(shapes, none_filler=None): - padded = [] - for shape in nest.flatten(shapes): - shape = tensor_shape.TensorShape(shape) - shape = [ - none_filler if d.value is None else d - for d in shape - ] - padded.append(shape) - return nest.pack_sequence_as(shapes, padded) - - def batching_fn(bucket_id, grouped_dataset): - """Batch elements in dataset.""" - batch_size = window_size_fn(bucket_id) - if no_padding: - return grouped_dataset.batch(batch_size) - none_filler = None - if pad_to_bucket_boundary: - err_msg = ("When pad_to_bucket_boundary=True, elements must have " - "length < max(bucket_boundaries).") - check = check_ops.assert_less( - bucket_id, - constant_op.constant(len(bucket_batch_sizes) - 1, - dtype=dtypes.int64), - message=err_msg) - with ops.control_dependencies([check]): - boundaries = constant_op.constant(bucket_boundaries, - dtype=dtypes.int64) - bucket_boundary = boundaries[bucket_id] - none_filler = bucket_boundary - 1 - shapes = make_padded_shapes( - padded_shapes or grouped_dataset.output_shapes, - 
none_filler=none_filler) - return grouped_dataset.padded_batch(batch_size, shapes, padding_values) - - def _apply_fn(dataset): - return dataset.apply( - group_by_window(element_to_bucket_id, batching_fn, - window_size_func=window_size_fn)) - - return _apply_fn - - -def _map_x_dataset(map_func): - """A transformation that maps `map_func` across its input. - - This transformation is similar to `tf.data.Dataset.map`, but in addition to - supporting dense and sparse tensor inputs, it also supports dataset inputs. - - Args: - map_func: A function mapping a nested structure of tensors and/or datasets - (having shapes and types defined by `self.output_shapes` and - `self.output_types`) to another nested structure of tensors and/or - datasets. - - Returns: - Dataset: A `Dataset`. - """ - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - return _MapXDataset(dataset, map_func) - - return _apply_fn - - -# TODO(b/115382007) Remove this once canned reducers move to core. -def window_dataset(window_size): - """A transformation that creates window datasets from the input dataset. - - The resulting datasets will contain `window_size` elements (or - `N % window_size` for the last dataset if `window_size` does not divide the - number of input elements `N` evenly). - - Args: - window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - consecutive elements of the input dataset to combine into a window. - - Returns: - Dataset: A `Dataset`. 
- """ - - def _apply_fn(dataset): - return dataset_ops.WindowDataset( - dataset, - size=window_size, - shift=window_size, - stride=1, - drop_remainder=False) - - return _apply_fn - - -class _GroupByReducerDataset(dataset_ops.UnaryDataset): - """A `Dataset` that groups its input and performs a reduction.""" - - def __init__(self, input_dataset, key_func, reducer): - """See `group_by_reducer()` for details.""" - super(_GroupByReducerDataset, self).__init__(input_dataset) + return grouping.bucket_by_sequence_length( + element_length_func, bucket_boundaries, bucket_batch_sizes, padded_shapes, + padding_values, pad_to_bucket_boundary, no_padding) - self._input_dataset = input_dataset - self._make_key_func(key_func, input_dataset) - self._make_init_func(reducer.init_func) - self._make_reduce_func(reducer.reduce_func, input_dataset) - self._make_finalize_func(reducer.finalize_func) - - def _make_key_func(self, key_func, input_dataset): - """Make wrapping Defun for key_func.""" - wrapped_func = dataset_ops.StructuredFunctionWrapper( - key_func, "tf.contrib.data.group_by_reducer()", input_dataset) - if not ( - wrapped_func.output_types == dtypes.int64 and - wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): - raise ValueError( - "`key_func` must return a single tf.int64 tensor. 
" - "Got type=%s and shape=%s" - % (wrapped_func.output_types, wrapped_func.output_shapes)) - self._key_func = wrapped_func.function - - def _make_init_func(self, init_func): - """Make wrapping Defun for init_func.""" - wrapped_func = dataset_ops.StructuredFunctionWrapper( - init_func, "tf.contrib.data.group_by_reducer()", - input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(), - input_types=dtypes.int64) - self._init_func = wrapped_func.function - self._state_classes = wrapped_func.output_classes - self._state_shapes = wrapped_func.output_shapes - self._state_types = wrapped_func.output_types - - def _make_reduce_func(self, reduce_func, input_dataset): - """Make wrapping Defun for reduce_func.""" - - # Iteratively rerun the reduce function until reaching a fixed point on - # `self._state_shapes`. - need_to_rerun = True - while need_to_rerun: - - wrapped_func = dataset_ops.StructuredFunctionWrapper( - reduce_func, "tf.contrib.data.group_by_reducer()", - input_classes=(self._state_classes, input_dataset.output_classes), - input_shapes=(self._state_shapes, input_dataset.output_shapes), - input_types=(self._state_types, input_dataset.output_types), - add_to_graph=False) - - # Extract and validate class information from the returned values. - for new_state_class, state_class in zip( - nest.flatten(wrapped_func.output_classes), - nest.flatten(self._state_classes)): - if not issubclass(new_state_class, state_class): - raise TypeError( - "The element classes for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_classes, wrapped_func.output_classes)) - - # Extract and validate type information from the returned values. - for new_state_type, state_type in zip( - nest.flatten(wrapped_func.output_types), - nest.flatten(self._state_types)): - if new_state_type != state_type: - raise TypeError( - "The element types for the new state must match the initial " - "state. Expected %s; got %s." 
% - (self._state_types, wrapped_func.output_types)) - - # Extract shape information from the returned values. - flat_state_shapes = nest.flatten(self._state_shapes) - flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes) - weakened_state_shapes = [ - original.most_specific_compatible_shape(new) - for original, new in zip(flat_state_shapes, flat_new_state_shapes) - ] - - need_to_rerun = False - for original_shape, weakened_shape in zip(flat_state_shapes, - weakened_state_shapes): - if original_shape.ndims is not None and ( - weakened_shape.ndims is None or - original_shape.as_list() != weakened_shape.as_list()): - need_to_rerun = True - break - - if need_to_rerun: - self._state_shapes = nest.pack_sequence_as(self._state_shapes, - weakened_state_shapes) - - self._reduce_func = wrapped_func.function - self._reduce_func.add_to_graph(ops.get_default_graph()) - - def _make_finalize_func(self, finalize_func): - """Make wrapping Defun for finalize_func.""" - wrapped_func = dataset_ops.StructuredFunctionWrapper( - finalize_func, "tf.contrib.data.group_by_reducer()", - input_classes=self._state_classes, input_shapes=self._state_shapes, - input_types=self._state_types) - self._finalize_func = wrapped_func.function - self._output_classes = wrapped_func.output_classes - self._output_shapes = wrapped_func.output_shapes - self._output_types = wrapped_func.output_types - - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - def _as_variant_tensor(self): - return gen_dataset_ops.group_by_reducer_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._key_func.captured_inputs, - self._init_func.captured_inputs, - self._reduce_func.captured_inputs, - self._finalize_func.captured_inputs, - key_func=self._key_func, - init_func=self._init_func, - reduce_func=self._reduce_func, - 
finalize_func=self._finalize_func, - **dataset_ops.flat_structure(self)) - - -class _GroupByWindowDataset(dataset_ops.UnaryDataset): - """A `Dataset` that groups its input and performs a windowed reduction.""" - - def __init__(self, input_dataset, key_func, reduce_func, window_size_func): - """See `group_by_window()` for details.""" - super(_GroupByWindowDataset, self).__init__(input_dataset) - - self._input_dataset = input_dataset - - self._make_key_func(key_func, input_dataset) - self._make_reduce_func(reduce_func, input_dataset) - self._make_window_size_func(window_size_func) - - def _make_window_size_func(self, window_size_func): - """Make wrapping Defun for window_size_func.""" - def window_size_func_wrapper(key): - return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64) - wrapped_func = dataset_ops.StructuredFunctionWrapper( - window_size_func_wrapper, "tf.contrib.data.group_by_window()", - input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(), - input_types=dtypes.int64) - if not ( - wrapped_func.output_types == dtypes.int64 and - wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): - raise ValueError( - "`window_size_func` must return a single tf.int64 scalar tensor.") - self._window_size_func = wrapped_func.function - - def _make_key_func(self, key_func, input_dataset): - """Make wrapping Defun for key_func.""" - def key_func_wrapper(*args): - return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64) - wrapped_func = dataset_ops.StructuredFunctionWrapper( - key_func_wrapper, "tf.contrib.data.group_by_window()", input_dataset) - if not ( - wrapped_func.output_types == dtypes.int64 and - wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): - raise ValueError( - "`key_func` must return a single tf.int64 scalar tensor.") - self._key_func = wrapped_func.function - - def _make_reduce_func(self, reduce_func, input_dataset): - """Make wrapping Defun for reduce_func.""" - nested_dataset = 
dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access - wrapped_func = dataset_ops.StructuredFunctionWrapper( - reduce_func, "tf.contrib.data.reduce_by_window()", - input_classes=(ops.Tensor, nested_dataset), - input_shapes=(tensor_shape.scalar(), nested_dataset), - input_types=(dtypes.int64, nested_dataset), - experimental_nested_dataset_support=True) - if not isinstance( - wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access - raise TypeError("`reduce_func` must return a `Dataset` object.") - self._output_classes = wrapped_func.output_classes.output_classes - self._output_types = wrapped_func.output_types.output_types - self._output_shapes = wrapped_func.output_shapes.output_shapes - self._reduce_func = wrapped_func.function - - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - def _as_variant_tensor(self): - return gen_dataset_ops.group_by_window_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._key_func.captured_inputs, - self._reduce_func.captured_inputs, - self._window_size_func.captured_inputs, - key_func=self._key_func, - reduce_func=self._reduce_func, - window_size_func=self._window_size_func, - **dataset_ops.flat_structure(self)) - - -class Reducer(object): +class Reducer(grouping.Reducer): """A reducer is used for reducing a set of elements. 
A reducer is represented as a tuple of the three functions: @@ -507,58 +152,6 @@ class Reducer(object): 3) finalization function: state => result """ + @deprecation.deprecated(None, "Use `tf.data.experimental.Reducer(...)`.") def __init__(self, init_func, reduce_func, finalize_func): - self._init_func = init_func - self._reduce_func = reduce_func - self._finalize_func = finalize_func - - @property - def init_func(self): - return self._init_func - - @property - def reduce_func(self): - return self._reduce_func - - @property - def finalize_func(self): - return self._finalize_func - - -class _MapXDataset(dataset_ops.UnaryDataset): - """A `Dataset` that maps a function over elements in its input.""" - - def __init__(self, input_dataset, map_func): - """See `map_x_dataset()` for details.""" - super(_MapXDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - - wrapped_func = dataset_ops.StructuredFunctionWrapper( - map_func, - "tf.contrib.data.map_x_dataset()", - input_dataset, - experimental_nested_dataset_support=True) - self._output_classes = wrapped_func.output_classes - self._output_shapes = wrapped_func.output_shapes - self._output_types = wrapped_func.output_types - self._map_func = wrapped_func.function - - def _as_variant_tensor(self): - input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access - return gen_dataset_ops.map_dataset( - input_t, - self._map_func.captured_inputs, - f=self._map_func, - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types + super(Reducer, self).__init__(init_func, reduce_func, finalize_func) diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 
1ee9db1aa830934b6228ed8025a5230a8845ef17..f50da4d429f715418a95cf177a3f4b5d273c8844 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -17,20 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import stateless -from tensorflow.contrib.data.python.ops import random_ops -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import readers -from tensorflow.python.data.util import nest -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_experimental_dataset_ops -from tensorflow.python.ops import math_ops +from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.parallel_interleave(...)`.") def parallel_interleave(map_func, cycle_length, block_length=1, @@ -80,12 +72,9 @@ def parallel_interleave(map_func, A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - def _apply_fn(dataset): - return readers.ParallelInterleaveDataset( - dataset, map_func, cycle_length, block_length, sloppy, - buffer_output_elements, prefetch_input_elements) - - return _apply_fn + return interleave_ops.parallel_interleave( + map_func, cycle_length, block_length, sloppy, buffer_output_elements, + prefetch_input_elements) @deprecation.deprecated( @@ -139,63 +128,12 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. 
""" - def _apply_fn(dataset): - return readers.ParallelInterleaveDataset( - dataset, - map_func, - cycle_length, - block_length, - sloppy=True, - buffer_output_elements=None, - prefetch_input_elements=None) - - return _apply_fn - - -class _DirectedInterleaveDataset(dataset_ops.Dataset): - """A substitute for `Dataset.interleave()` on a fixed list of datasets.""" - - def __init__(self, selector_input, data_inputs): - self._selector_input = selector_input - self._data_inputs = list(data_inputs) - - for data_input in data_inputs[1:]: - if (data_input.output_types != data_inputs[0].output_types or - data_input.output_classes != data_inputs[0].output_classes): - raise TypeError("All datasets must have the same type and class.") - - def _as_variant_tensor(self): - # pylint: disable=protected-access - return ( - gen_experimental_dataset_ops.experimental_directed_interleave_dataset( - self._selector_input._as_variant_tensor(), [ - data_input._as_variant_tensor() - for data_input in self._data_inputs - ], **dataset_ops.flat_structure(self))) - # pylint: enable=protected-access - - def _inputs(self): - return [self._selector_input] + self._data_inputs - - @property - def output_classes(self): - return self._data_inputs[0].output_classes - - @property - def output_shapes(self): - ret = self._data_inputs[0].output_shapes - for data_input in self._data_inputs[1:]: - ret = nest.pack_sequence_as(ret, [ - ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip( - nest.flatten(ret), nest.flatten(data_input.output_shapes)) - ]) - return ret - - @property - def output_types(self): - return self._data_inputs[0].output_types + return interleave_ops.parallel_interleave( + map_func, cycle_length, block_length, sloppy=True) +@deprecation.deprecated(None, + "Use `tf.data.experimental.sample_from_datasets(...)`.") def sample_from_datasets(datasets, weights=None, seed=None): """Samples elements at random from the datasets in `datasets`. 
@@ -219,64 +157,11 @@ def sample_from_datasets(datasets, weights=None, seed=None): ValueError: If the `weights` argument is specified and does not match the length of the `datasets` element. """ - num_datasets = len(datasets) - if not isinstance(weights, dataset_ops.Dataset): - if weights is None: - # Select inputs with uniform probability. - logits = [[1.0] * num_datasets] - - else: - # Use the given `weights` as the probability of choosing the respective - # input. - weights = ops.convert_to_tensor(weights, name="weights") - if weights.dtype not in (dtypes.float32, dtypes.float64): - raise TypeError("`weights` must be convertible to a tensor of " - "`tf.float32` or `tf.float64` elements.") - if not weights.shape.is_compatible_with([num_datasets]): - raise ValueError( - "`weights` must be a vector of length `len(datasets)`.") - - # The `stateless_multinomial()` op expects log-probabilities, as opposed - # to weights. - logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0) - - # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it - # is a `Dataset`, it is possible that evaluating it has a side effect the - # user depends on. - if len(datasets) == 1: - return datasets[0] - - def select_dataset_constant_logits(seed): - return array_ops.squeeze( - stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) - - selector_input = dataset_ops.MapDataset( - random_ops.RandomDataset(seed).batch(2), - select_dataset_constant_logits, - use_inter_op_parallelism=False) - - else: - # Use each element of the given `weights` dataset as the probability of - # choosing the respective input. - - # The `stateless_multinomial()` op expects log-probabilities, as opposed to - # weights. 
- logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits")) - - def select_dataset_varying_logits(logits, seed): - return array_ops.squeeze( - stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) - - logits_and_seeds = dataset_ops.Dataset.zip( - (logits_ds, random_ops.RandomDataset(seed).batch(2))) - selector_input = dataset_ops.MapDataset( - logits_and_seeds, - select_dataset_varying_logits, - use_inter_op_parallelism=False) - - return _DirectedInterleaveDataset(selector_input, datasets) + return interleave_ops.sample_from_datasets(datasets, weights, seed) +@deprecation.deprecated(None, + "Use `tf.data.experimental.choose_from_datasets(...)`.") def choose_from_datasets(datasets, choice_dataset): """Creates a dataset that deterministically chooses elements from `datasets`. @@ -312,10 +197,4 @@ def choose_from_datasets(datasets, choice_dataset): TypeError: If the `datasets` or `choice_dataset` arguments have the wrong type. """ - if not (choice_dataset.output_types == dtypes.int64 - and choice_dataset.output_shapes.is_compatible_with( - tensor_shape.scalar()) - and choice_dataset.output_classes == ops.Tensor): - raise TypeError("`choice_dataset` must be a dataset of scalar " - "`tf.int64` tensors.") - return _DirectedInterleaveDataset(choice_dataset, datasets) + return interleave_ops.choose_from_datasets(datasets, choice_dataset) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index 18515e21edfe0449514ab4f21683a600eaf48910..48c325c86f74b4c922e70a33212b49196b34e357 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -16,15 +16,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops -from 
tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import checkpoint_management -from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training import session_run_hook +from tensorflow.python.data.experimental.ops import iterator_ops +from tensorflow.python.util import deprecation + +@deprecation.deprecated( + None, "Use `tf.data.experimental.make_saveable_from_iterator(...)`.") def make_saveable_from_iterator(iterator): """Returns a SaveableObject for saving/restore iterator state using Saver. @@ -60,27 +58,10 @@ def make_saveable_from_iterator(iterator): Note: Not all iterators support checkpointing yet. Attempting to save the state of an unsupported iterator will throw an error. """ - return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access - - -class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject): - """SaveableObject for saving/restoring iterator state.""" + return iterator_ops.make_saveable_from_iterator(iterator) - def __init__(self, iterator_resource): - serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource) - specs = [ - saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "", - iterator_resource.name + "-state") - ] - super(_Saveable, self).__init__(iterator_resource, specs, - iterator_resource.name) - def restore(self, restored_tensors, unused_restored_shapes): - with ops.colocate_with(self.op): - return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0]) - - -class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): +class CheckpointInputPipelineHook(iterator_ops.CheckpointInputPipelineHook): """Checkpoints input pipeline state every N steps or seconds. This hook saves the state of the iterators in the `Graph` so that when @@ -125,135 +106,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): collector when building the eval graph. 
""" + @deprecation.deprecated( + None, "Use `tf.data.experimental.CheckpointInputPipelineHook(...)`.") def __init__(self, estimator): - """Initializes a `CheckpointInputPipelineHook`. - - Args: - estimator: Estimator. - - Raises: - ValueError: One of `save_steps` or `save_secs` should be set. - ValueError: At most one of saver or scaffold should be set. - """ - # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or - # of the form "input__.ckpt" for distributed pipelines. - # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is - # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix - # to be different to avoid conflicts with the model checkpoint. - - # pylint: disable=protected-access - checkpoint_prefix = "input" - if estimator._config.num_worker_replicas > 1: - # Distributed setting. - suffix = "_{}_{}".format(estimator._config.task_type, - estimator._config.task_id) - checkpoint_prefix += suffix - # pylint: enable=protected-access - - # We use a composition paradigm instead of inheriting from - # `CheckpointSaverHook` because `Estimator` does an `isinstance` check - # to check whether a `CheckpointSaverHook` is already present in the list - # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook` - # would thwart this behavior. This hook checkpoints *only the iterators* - # and not the graph variables. - self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook( - estimator.model_dir, - save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access - save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access - checkpoint_basename=checkpoint_prefix + ".ckpt") - - # Name for the protocol buffer file that will contain the list of most - # recent checkpoints stored as a `CheckpointState` protocol buffer. 
- # This file, kept in the same directory as the checkpoint files, is - # automatically managed by the `Saver` to keep track of recent checkpoints. - # The default name used by the `Saver` for this file is "checkpoint". Here - # we use the name "checkpoint_" so that in case the - # `checkpoint_dir` is the same as the model checkpoint directory, there are - # no conflicts during restore. - self._latest_filename = "checkpoint_" + checkpoint_prefix - self._first_run = True - - def begin(self): - # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS` - # collection if no `Saver` or `Scaffold` is provided. - # pylint: disable=protected-access - if (self._checkpoint_saver_hook._saver is None and - self._checkpoint_saver_hook._scaffold is None): - iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS) - saveables = [_Saveable(i) for i in iterators] - self._checkpoint_saver_hook._saver = _CustomSaver(saveables, - self._latest_filename) - # pylint: enable=protected-access - self._checkpoint_saver_hook.begin() - - def _restore_or_save_initial_ckpt(self, session): - # Ideally this should be run in after_create_session but is not for the - # following reason: - # Currently there is no way of enforcing an order of running the - # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook` - # is run *after* this hook. That is troublesome because - # 1. If a checkpoint exists and this hook restores it, the initializer hook - # will override it. - # 2. If no checkpoint exists, this hook will try to save an initialized - # iterator which will result in an exception. - # - # As a temporary fix we enter the following implicit contract between this - # hook and the _DatasetInitializerHook. - # 1. The _DatasetInitializerHook initializes the iterator in the call to - # after_create_session. - # 2. 
This hook saves the iterator on the first call to `before_run()`, which - # is guaranteed to happen after `after_create_session()` of all hooks - # have been run. - - # Check if there is an existing checkpoint. If so, restore from it. - # pylint: disable=protected-access - latest_checkpoint_path = checkpoint_management.latest_checkpoint( - self._checkpoint_saver_hook._checkpoint_dir, - latest_filename=self._latest_filename) - if latest_checkpoint_path: - self._checkpoint_saver_hook._get_saver().restore(session, - latest_checkpoint_path) - else: - # The checkpoint saved here is the state at step "global_step". - # Note: We do not save the GraphDef or MetaGraphDef here. - global_step = session.run(self._checkpoint_saver_hook._global_step_tensor) - self._checkpoint_saver_hook._save(session, global_step) - self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step) - # pylint: enable=protected-access - - def before_run(self, run_context): - if self._first_run: - self._restore_or_save_initial_ckpt(run_context.session) - self._first_run = False - return self._checkpoint_saver_hook.before_run(run_context) - - def after_run(self, run_context, run_values): - self._checkpoint_saver_hook.after_run(run_context, run_values) - - def end(self, session): - self._checkpoint_saver_hook.end(session) - - -class _CustomSaver(saver_lib.Saver): - """`Saver` with a different default `latest_filename`. - - This is used in the `CheckpointInputPipelineHook` to avoid conflicts with - the model ckpt saved by the `CheckpointSaverHook`. 
- """ - - def __init__(self, var_list, latest_filename): - super(_CustomSaver, self).__init__(var_list) - self._latest_filename = latest_filename - - def save(self, - sess, - save_path, - global_step=None, - latest_filename=None, - meta_graph_suffix="meta", - write_meta_graph=True, - write_state=True, - strip_default_attrs=False): - return super(_CustomSaver, self).save( - sess, save_path, global_step, latest_filename or self._latest_filename, - meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs) + super(CheckpointInputPipelineHook, self).__init__(estimator) diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py index cfbba701b06d9a55e72e038702e0fb45c904f6ba..3aeee9d8e42dce5af133afeeab4a8c97e50d5571 100644 --- a/tensorflow/contrib/data/python/ops/parsing_ops.py +++ b/tensorflow/contrib/data/python/ops/parsing_ops.py @@ -17,92 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import parsing_ops +from tensorflow.python.data.experimental.ops import parsing_ops +from tensorflow.python.util import deprecation -class _ParseExampleDataset(dataset_ops.UnaryDataset): - """A `Dataset` that parses `example` dataset into a `dict` dataset.""" - - def __init__(self, input_dataset, features, num_parallel_calls): - super(_ParseExampleDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - if not all(types == dtypes.string - for types in nest.flatten(input_dataset.output_types)): - raise TypeError("Input dataset should be a dataset of vectors of strings") - self._num_parallel_calls = 
num_parallel_calls - # pylint: disable=protected-access - self._features = parsing_ops._prepend_none_dimension(features) - # sparse_keys and dense_keys come back sorted here. - (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults, - dense_shapes) = parsing_ops._features_to_raw_params( - self._features, [ - parsing_ops.VarLenFeature, parsing_ops.SparseFeature, - parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature - ]) - # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature. - (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes, - dense_shape_as_shape) = parsing_ops._process_raw_parameters( - None, dense_defaults, sparse_keys, sparse_types, dense_keys, - dense_types, dense_shapes) - # pylint: enable=protected-access - self._sparse_keys = sparse_keys - self._sparse_types = sparse_types - self._dense_keys = dense_keys - self._dense_defaults = dense_defaults_vec - self._dense_shapes = dense_shapes - self._dense_types = dense_types - dense_output_shapes = [ - self._input_dataset.output_shapes.concatenate(shape) - for shape in dense_shape_as_shape - ] - sparse_output_shapes = [ - self._input_dataset.output_shapes.concatenate([None]) - for _ in range(len(sparse_keys)) - ] - - self._output_shapes = dict( - zip(self._dense_keys + self._sparse_keys, - dense_output_shapes + sparse_output_shapes)) - self._output_types = dict( - zip(self._dense_keys + self._sparse_keys, - self._dense_types + self._sparse_types)) - self._output_classes = dict( - zip(self._dense_keys + self._sparse_keys, - [ops.Tensor for _ in range(len(self._dense_defaults))] + - [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys)) - ])) - - def _as_variant_tensor(self): - return gen_dataset_ops.parse_example_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._num_parallel_calls, - self._dense_defaults, - self._sparse_keys, - self._dense_keys, - self._sparse_types, - 
self._dense_shapes, - **dataset_ops.flat_structure(self)) - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - @property - def output_classes(self): - return self._output_classes - - -# TODO(b/111553342): add arguments names and example names as well. +@deprecation.deprecated( + None, "Use `tf.data.experimental.parse_example_dataset(...)`.") def parse_example_dataset(features, num_parallel_calls=1): """A transformation that parses `Example` protos into a `dict` of tensors. @@ -130,21 +50,4 @@ def parse_example_dataset(features, num_parallel_calls=1): Raises: ValueError: if features argument is None. """ - if features is None: - raise ValueError("Missing: features was %s." % features) - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls) - if any([ - isinstance(feature, parsing_ops.SparseFeature) - for _, feature in features.items() - ]): - # pylint: disable=protected-access - # pylint: disable=g-long-lambda - out_dataset = out_dataset.map( - lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features( - features, x), num_parallel_calls=num_parallel_calls) - return out_dataset - - return _apply_fn + return parsing_ops.parse_example_dataset(features, num_parallel_calls) diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index 46f82e453a839a19ed4624e99206687fc0b5488f..adfb390cd9a6b159fe3887666993c6e9d6c758d8 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -17,321 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import warnings - -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import 
iterator_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse -from tensorflow.python.eager import context -from tensorflow.python.framework import device as framework_device -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops -from tensorflow.python.ops import resource_variable_ops - - -def function_buffering_resource(string_arg, - target_device, - f, - buffer_size, - output_types, - container="", - shared_name=None, - name=None): - """Creates a FunctionBufferingResource. - - A FunctionBufferingResource fills up a buffer by calling a function `f` on - `target_device`. `f` should take in only a single string argument as input. - - Args: - string_arg: The single string argument to the function. - target_device: The device to run `f` on. - f: The function to be executed. - buffer_size: Size of the buffer to be populated. - output_types: The output types generated by the function. - container: (Optional) string. Defaults to "". - shared_name: (Optional) string. - name: (Optional) string to name the op. - - Returns: - Handle to a FunctionBufferingResource. 
- """ - if shared_name is None: - shared_name = "" - return ged_ops.experimental_function_buffering_resource( - string_arg=string_arg, - target_device=target_device, - shared_name=shared_name, - f=f, - buffer_size=buffer_size, - container=container, - name=name, - output_types=output_types) - - -def function_buffering_resource_get_next(function_buffer_resource, - output_types, - name=None): - return ged_ops.experimental_function_buffering_resource_get_next( - function_buffer_resource=function_buffer_resource, - output_types=output_types, - name=name) - - -def function_buffering_resource_reset(function_buffer_resource, name=None): - return ged_ops.experimental_function_buffering_resource_reset( - function_buffer_resource=function_buffer_resource, name=name) - - -# pylint: disable=protected-access -class _PrefetchToDeviceIterator(object): - """A replacement for `tf.data.Iterator` that prefetches to another device. - - Args: - input_dataset: The input dataset - one_shot: If true, we make a one shot iterator that's already initialized. - device: A fully specified device string where we want to prefetch to - buffer_size: Size of the prefetching buffer. - shared_name: (Optional.) If non-empty, the returned iterator will be - shared under the given name across multiple sessions that share the - same devices (e.g. when using a remote server). - - Returns: - An Iterator type object. 
- """ - - def __init__(self, - input_dataset, - one_shot, - device, - buffer_size, - shared_name=None): - self._input_dataset = input_dataset - self._get_next_call_count = 0 - self._one_shot = one_shot - if shared_name is None: - shared_name = "" - - if self._one_shot: - self._input_iterator = input_dataset.make_one_shot_iterator() - else: - self._input_iterator = iterator_ops.Iterator.from_structure( - self._input_dataset.output_types, self._input_dataset.output_shapes, - shared_name, self._input_dataset.output_classes) - input_iterator_handle = self._input_iterator.string_handle() - - @function.Defun(dtypes.string) - def _prefetch_fn(handle): - """Prefetches one element from `input_iterator`.""" - remote_iterator = iterator_ops.Iterator.from_string_handle( - handle, self._input_iterator.output_types, - self._input_iterator.output_shapes, - self._input_iterator.output_classes) - ret = remote_iterator.get_next() - return nest.flatten(sparse.serialize_sparse_tensors(ret)) - - iterator_device = ged_ops.experimental_iterator_get_device( - self._input_iterator._iterator_resource) - - with ops.device(device): - self._buffering_resource = function_buffering_resource( - f=_prefetch_fn, - target_device=iterator_device, - string_arg=input_iterator_handle, - buffer_size=buffer_size, - shared_name=shared_name, - output_types=nest.flatten( - sparse.as_dense_types(self._input_dataset.output_types, - self._input_dataset.output_classes))) - - if not self._one_shot: - reset_op = function_buffering_resource_reset(self._buffering_resource) - with ops.control_dependencies([reset_op]): - self._initializer = self._input_iterator.make_initializer( - self._input_dataset) - - def get_next(self, name=None): - """See `tf.data.Iterator.get_next`.""" - self._get_next_call_count += 1 - if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: - warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) - - flat_ret = ged_ops.experimental_function_buffering_resource_get_next( 
- self._buffering_resource, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - name=name) - - ret = sparse.deserialize_sparse_tensors( - nest.pack_sequence_as(self.output_types, flat_ret), - self.output_types, self.output_shapes, self.output_classes) - - for tensor, shape in zip( - nest.flatten(ret), nest.flatten(self.output_shapes)): - if isinstance(tensor, ops.Tensor): - tensor.set_shape(shape) - - return ret - - @property - def initializer(self): - if self._one_shot: - raise NotImplementedError("Can't initialize a one_shot_iterator") - return self._initializer - - @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types - - -class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator): - """A replacement for `tf.data.Iterator` that prefetches to another device. - - Args: - input_dataset: The input dataset - one_shot: If true, we make a one shot iterator that's already initialized. - device: A fully specified device string where we want to prefetch to - buffer_size: Size of the prefetching buffer. - shared_name: (Optional.) If non-empty, the returned iterator will be - shared under the given name across multiple sessions that share the - same devices (e.g. when using a remote server). - - Returns: - An Iterator type object. 
- """ - - def __init__(self, - input_dataset, - device, - buffer_size): - with ops.device("/device:CPU:0"): - super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset) - input_iterator_handle = gen_dataset_ops.iterator_to_string_handle( - self._resource) - - self._device = device - - @function.Defun(dtypes.string) - def _prefetch_fn(handle): - """Prefetches one element from `input_iterator`.""" - remote_iterator = iterator_ops.Iterator.from_string_handle( - handle, self.output_types, self.output_shapes, self.output_classes) - ret = remote_iterator.get_next() - return nest.flatten(sparse.serialize_sparse_tensors(ret)) - - _prefetch_fn.add_to_graph(None) - - with ops.device(device): - self._buffering_resource = function_buffering_resource( - f=_prefetch_fn, - output_types=self._flat_output_types, - target_device=ged_ops.experimental_iterator_get_device( - self._resource), - string_arg=input_iterator_handle, - buffer_size=buffer_size, - shared_name=iterator_ops._generate_shared_name( - "function_buffer_resource")) - - def _next_internal(self): - """Returns a nested structure of `tf.Tensor`s containing the next element. - """ - # This runs in sync mode as iterators use an error status to communicate - # that there is no more data to iterate over. 
- # TODO(b/77291417): Fix - with context.execution_mode(context.SYNC): - with ops.device(self._device): - ret = ged_ops.experimental_function_buffering_resource_get_next( - function_buffer_resource=self._buffering_resource, - output_types=self._flat_output_types) - return sparse.deserialize_sparse_tensors( - nest.pack_sequence_as(self._output_types, ret), self._output_types, - self._output_shapes, self._output_classes) -# pylint: enable=protected-access - - -class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset): - """A `Dataset` whose iterator prefetches elements to another device.""" - - def __init__(self, input_dataset, device, buffer_size): - super(_PrefetchToDeviceDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - self._device = device - self._buffer_size = buffer_size if buffer_size is not None else 1 - - # The static analysis cannot tell that the eager iterator's superclass has - # a `next()` method. - # pylint: disable=non-iterator-returned - def __iter__(self): - """Creates an `Iterator` for enumerating the elements of this dataset. - - The returned iterator implements the Python iterator protocol and therefore - can only be used in eager mode. - - Returns: - An `Iterator` over the elements of this dataset. - - Raises: - RuntimeError: If eager execution is enabled. 
- """ - if context.executing_eagerly(): - return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device, - self._buffer_size) - else: - raise RuntimeError("dataset.__iter__() is only supported when eager " - "execution is enabled.") - # pylint: enable=non-iterator-returned - - def make_one_shot_iterator(self): - if context.executing_eagerly(): - return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device, - self._buffer_size) - else: - return _PrefetchToDeviceIterator(self._input_dataset, one_shot=True, - device=self._device, - buffer_size=self._buffer_size) - - def make_initializable_iterator(self, shared_name=None): - return _PrefetchToDeviceIterator( - self._input_dataset, - one_shot=False, - device=self._device, - buffer_size=self._buffer_size, - shared_name=shared_name) - - def _as_variant_tensor(self): - # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset - # transformation methods is called. - # TODO(mrry): Investigate support for chaining further transformations after - # the prefetch, including GPU support. - raise NotImplementedError("`prefetch_to_device()` must be the last " - "transformation in a dataset pipeline.") - - @property - def output_types(self): - return self._input_dataset.output_types - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_classes(self): - return self._input_dataset.output_classes +from tensorflow.python.data.experimental.ops import prefetching_ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.prefetch_to_device(...)`.") def prefetch_to_device(device, buffer_size=None): """A transformation that prefetches dataset values to the given `device`. @@ -347,12 +38,10 @@ def prefetch_to_device(device, buffer_size=None): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. 
""" - def _apply_fn(dataset): - return _PrefetchToDeviceDataset(dataset, device, buffer_size) - - return _apply_fn + return prefetching_ops.prefetch_to_device(device, buffer_size) +@deprecation.deprecated(None, "Use `tf.data.experimental.copy_to_device(...)`.") def copy_to_device(target_device, source_device="/cpu:0"): """A transformation that copies dataset elements to the given `target_device`. @@ -364,165 +53,4 @@ def copy_to_device(target_device, source_device="/cpu:0"): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - - def _apply_fn(dataset): - return _CopyToDeviceDataset( - dataset, target_device=target_device, source_device=source_device) - - return _apply_fn - - -# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate -# all inputs to the Op are in host memory, thereby avoiding some unnecessary -# Sends and Recvs. -class _CopyToDeviceDataset(dataset_ops.UnaryDataset): - """A `Dataset` that copies elements to another device.""" - - def __init__(self, input_dataset, target_device, source_device="/cpu:0"): - """Constructs a _CopyToDeviceDataset. - - Args: - input_dataset: `Dataset` to be copied - target_device: The name of the device to which elements would be copied. - source_device: Device where input_dataset would be placed. 
- """ - super(_CopyToDeviceDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - self._target_device = target_device - spec = framework_device.DeviceSpec().from_string(self._target_device) - self._is_gpu_target = (spec.device_type == "GPU") - self._source_device_string = source_device - self._source_device = ops.convert_to_tensor(source_device) - - self._flat_output_shapes = nest.flatten( - sparse.as_dense_shapes(self._input_dataset.output_shapes, - self._input_dataset.output_classes)) - self._flat_output_types = nest.flatten( - sparse.as_dense_types(self._input_dataset.output_types, - self._input_dataset.output_classes)) - - @function.Defun() - def _init_func(): - """Creates an iterator for the input dataset. - - Returns: - A `string` tensor that encapsulates the iterator created. - """ - # pylint: disable=protected-access - ds_variant = self._input_dataset._as_variant_tensor() - resource = gen_dataset_ops.anonymous_iterator( - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - with ops.control_dependencies( - [gen_dataset_ops.make_iterator(ds_variant, resource)]): - return gen_dataset_ops.iterator_to_string_handle(resource) - - @function.Defun() - def _remote_init_func(): - return functional_ops.remote_call( - target=self._source_device, - args=_init_func.captured_inputs, - Tout=[dtypes.string], - f=_init_func) - - self._init_func = _remote_init_func - self._init_captured_args = _remote_init_func.captured_inputs - - @function.Defun(dtypes.string) - def _next_func(string_handle): - """Calls get_next for created iterator. 
- - Args: - string_handle: An iterator string handle created by _init_func - Returns: - The elements generated from `input_dataset` - """ - with ops.device(self._source_device_string): - iterator = iterator_ops.Iterator.from_string_handle( - string_handle, self.output_types, self.output_shapes, - self.output_classes) - ret = iterator.get_next() - return nest.flatten(sparse.serialize_sparse_tensors(ret)) - - @function.Defun(dtypes.string) - def _remote_next_func(string_handle): - return functional_ops.remote_call( - target=self._source_device, - args=[string_handle] + _next_func.captured_inputs, - Tout=self._flat_output_types, - f=_next_func) - - self._next_func = _remote_next_func - self._next_captured_args = _remote_next_func.captured_inputs - - @function.Defun(dtypes.string) - def _finalize_func(string_handle): - """Destroys the iterator resource created. - - Args: - string_handle: An iterator string handle created by _init_func - Returns: - Tensor constant 0 - """ - iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( - string_handle, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - with ops.control_dependencies([ - resource_variable_ops.destroy_resource_op( - iterator_resource, ignore_lookup_error=True)]): - return array_ops.constant(0, dtypes.int64) - - @function.Defun(dtypes.string) - def _remote_finalize_func(string_handle): - return functional_ops.remote_call( - target=self._source_device, - args=[string_handle] + _finalize_func.captured_inputs, - Tout=[dtypes.int64], - f=_finalize_func) - - self._finalize_func = _remote_finalize_func - self._finalize_captured_args = _remote_finalize_func.captured_inputs - - g = ops.get_default_graph() - _remote_init_func.add_to_graph(g) - _remote_next_func.add_to_graph(g) - _remote_finalize_func.add_to_graph(g) - # pylint: enable=protected-scope - - # The one_shot_iterator implementation needs a 0 arg _make_dataset function - # that thereby captures all the inputs required 
to create the dataset. Since - # there are strings that are inputs to the GeneratorDataset which can't be - # placed on a GPU, this fails for the GPU case. Therefore, disabling it for - # GPU - def make_one_shot_iterator(self): - if self._is_gpu_target: - raise ValueError("Cannot create a one shot iterator when using " - "`tf.contrib.data.copy_to_device()` on GPU. Please use " - "`Dataset.make_initializable_iterator()` instead.") - else: - return super(_CopyToDeviceDataset, self).make_one_shot_iterator() - - def _as_variant_tensor(self): - with ops.device(self._target_device): - return gen_dataset_ops.generator_dataset( - self._init_captured_args, - self._next_captured_args, - self._finalize_captured_args, - init_func=self._init_func, - next_func=self._next_func, - finalize_func=self._finalize_func, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - - @property - def output_types(self): - return self._input_dataset.output_types - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_classes(self): - return self._input_dataset.output_classes + return prefetching_ops.copy_to_device(target_device, source_device) diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py index 344a0763c839f4a769177f7b76534326548bb543..2c951256368a5ffdbc2be424cef12eafc6ecd782 100644 --- a/tensorflow/contrib/data/python/ops/random_ops.py +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -17,36 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import random_seed -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops +from 
tensorflow.python.data.experimental.ops import random_ops +from tensorflow.python.util import deprecation -class RandomDataset(dataset_ops.DatasetSource): +class RandomDataset(random_ops.RandomDataset): """A `Dataset` of pseudorandom values.""" + @deprecation.deprecated( + None, "Use `tf.data.experimental.RandomDataset(...)`.") def __init__(self, seed=None): - """A `Dataset` of pseudorandom values.""" - super(RandomDataset, self).__init__() - self._seed, self._seed2 = random_seed.get_seed(seed) - - def _as_variant_tensor(self): - return gen_dataset_ops.random_dataset( - seed=self._seed, - seed2=self._seed2, - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return tensor_shape.scalar() - - @property - def output_types(self): - return dtypes.int64 + super(RandomDataset, self).__init__(seed) diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 360971e20008fa817e94a55234e3a566db91af06..4601376dff47e161962e92678883039c4b88bab7 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -17,295 +17,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import csv - -import numpy as np - -from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import interleave_ops -from tensorflow.contrib.data.python.ops import optimization -from tensorflow.contrib.data.python.ops import parsing_ops -from tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.experimental.ops import readers from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers -from tensorflow.python.data.util import convert from 
tensorflow.python.data.util import nest -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.lib.io import file_io -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_experimental_dataset_ops -from tensorflow.python.platform import gfile from tensorflow.python.util import deprecation -_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32, - dtypes.int64, dtypes.string) - - -def _is_valid_int32(str_val): - try: - # Checks equality to prevent int32 overflow - return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype( - str_val) - except (ValueError, OverflowError): - return False - - -def _is_valid_int64(str_val): - try: - dtypes.int64.as_numpy_dtype(str_val) - return True - except (ValueError, OverflowError): - return False - - -def _is_valid_float(str_val, float_dtype): - try: - return float_dtype.as_numpy_dtype(str_val) < np.inf - except ValueError: - return False - - -def _infer_type(str_val, na_value, prev_type): - """Given a string, infers its tensor type. - - Infers the type of a value by picking the least 'permissive' type possible, - while still allowing the previous type inference for this column to be valid. - - Args: - str_val: String value to infer the type of. - na_value: Additional string to recognize as a NA/NaN CSV value. - prev_type: Type previously inferred based on values of this column that - we've seen up till now. - Returns: - Inferred dtype. 
- """ - if str_val in ("", na_value): - # If the field is null, it gives no extra information about its type - return prev_type - - type_list = [ - dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string - ] # list of types to try, ordered from least permissive to most - - type_functions = [ - _is_valid_int32, - _is_valid_int64, - lambda str_val: _is_valid_float(str_val, dtypes.float32), - lambda str_val: _is_valid_float(str_val, dtypes.float64), - lambda str_val: True, - ] # Corresponding list of validation functions - - for i in range(len(type_list)): - validation_fn = type_functions[i] - if validation_fn(str_val) and (prev_type is None or - prev_type in type_list[:i + 1]): - return type_list[i] - - -def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header): - """Generator that yields rows of CSV file(s) in order.""" - for fn in filenames: - with file_io.FileIO(fn, "r") as f: - rdr = csv.reader( - f, - delimiter=field_delim, - quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE) - if header: - next(rdr) # Skip header lines - - for csv_row in rdr: - if len(csv_row) != num_cols: - raise ValueError( - "Problem inferring types: CSV row has different number of fields " - "than expected.") - yield csv_row - - -def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim, - na_value, header, num_rows_for_inference, - select_columns): - """Infers column types from the first N valid CSV records of files.""" - if select_columns is None: - select_columns = range(num_cols) - inferred_types = [None] * len(select_columns) - - for i, csv_row in enumerate( - _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)): - if num_rows_for_inference is not None and i >= num_rows_for_inference: - break - - for j, col_index in enumerate(select_columns): - inferred_types[j] = _infer_type(csv_row[col_index], na_value, - inferred_types[j]) - - # Replace None's with a default type - inferred_types = [t or 
dtypes.string for t in inferred_types] - # Default to 0 or '' for null values - return [ - constant_op.constant([0 if t is not dtypes.string else ""], dtype=t) - for t in inferred_types - ] - - -def _infer_column_names(filenames, field_delim, use_quote_delim): - """Infers column names from first rows of files.""" - csv_kwargs = { - "delimiter": field_delim, - "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE - } - with file_io.FileIO(filenames[0], "r") as f: - try: - column_names = next(csv.reader(f, **csv_kwargs)) - except StopIteration: - raise ValueError(("Received StopIteration when reading the header line " - "of %s. Empty file?") % filenames[0]) - - for name in filenames[1:]: - with file_io.FileIO(name, "r") as f: - try: - if next(csv.reader(f, **csv_kwargs)) != column_names: - raise ValueError( - "Files have different column names in the header row.") - except StopIteration: - raise ValueError(("Received StopIteration when reading the header line " - "of %s. Empty file?") % filenames[0]) - return column_names - - -def _get_sorted_col_indices(select_columns, column_names): - """Transforms select_columns argument into sorted column indices.""" - names_to_indices = {n: i for i, n in enumerate(column_names)} - num_cols = len(column_names) - for i, v in enumerate(select_columns): - if isinstance(v, int): - if v < 0 or v >= num_cols: - raise ValueError( - "Column index %d specified in select_columns out of valid range." % - v) - continue - if v not in names_to_indices: - raise ValueError( - "Value '%s' specified in select_columns not a valid column index or " - "name." 
% v) - select_columns[i] = names_to_indices[v] - - # Sort and ensure there are no duplicates - result = sorted(set(select_columns)) - if len(result) != len(select_columns): - raise ValueError("select_columns contains duplicate columns") - return result - - -def _maybe_shuffle_and_repeat( - dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed): - """Optionally shuffle and repeat dataset, as requested.""" - if num_epochs != 1 and shuffle: - # Use shuffle_and_repeat for perf - return dataset.apply( - shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, - shuffle_seed)) - elif shuffle: - return dataset.shuffle(shuffle_buffer_size, shuffle_seed) - elif num_epochs != 1: - return dataset.repeat(num_epochs) - return dataset - - -def make_tf_record_dataset(file_pattern, - batch_size, - parser_fn=None, - num_epochs=None, - shuffle=True, - shuffle_buffer_size=None, - shuffle_seed=None, - prefetch_buffer_size=optimization.AUTOTUNE, - num_parallel_reads=None, - num_parallel_parser_calls=None, - drop_final_batch=False): - """Reads and optionally parses TFRecord files into a dataset. - - Provides common functionality such as batching, optional parsing, shuffling, - and performant defaults. - - Args: - file_pattern: List of files or patterns of TFRecord file paths. - See `tf.gfile.Glob` for pattern rules. - batch_size: An int representing the number of records to combine - in a single batch. - parser_fn: (Optional.) A function accepting string input to parse - and process the record contents. This function must map records - to components of a fixed shape, so they may be batched. By - default, uses the record contents unmodified. - num_epochs: (Optional.) An int specifying the number of times this - dataset is repeated. If None (the default), cycles through the - dataset forever. - shuffle: (Optional.) A bool that indicates whether the input - should be shuffled. Defaults to `True`. - shuffle_buffer_size: (Optional.) Buffer size to use for - shuffling. 
A large buffer size ensures better shuffling, but - increases memory usage and startup time. - shuffle_seed: (Optional.) Randomization seed to use for shuffling. - prefetch_buffer_size: (Optional.) An int specifying the number of - feature batches to prefetch for performance improvement. - Defaults to auto-tune. Set to 0 to disable prefetching. - num_parallel_reads: (Optional.) Number of threads used to read - records from files. By default or if set to a value >1, the - results will be interleaved. - num_parallel_parser_calls: (Optional.) Number of parallel - records to parse in parallel. Defaults to an automatic selection. - drop_final_batch: (Optional.) Whether the last batch should be - dropped in case its size is smaller than `batch_size`; the - default behavior is not to drop the smaller batch. - - Returns: - A dataset, where each element matches the output of `parser_fn` - except it will have an additional leading `batch-size` dimension, - or a `batch_size`-length 1-D tensor of strings if `parser_fn` is - unspecified. - """ - files = dataset_ops.Dataset.list_files( - file_pattern, shuffle=shuffle, seed=shuffle_seed) - - if num_parallel_reads is None: - # Note: We considered auto-tuning this value, but there is a concern - # that this affects the mixing of records from different files, which - # could affect training convergence/accuracy, so we are defaulting to - # a constant for now. - num_parallel_reads = 24 - dataset = core_readers.TFRecordDataset( - files, num_parallel_reads=num_parallel_reads) - - if shuffle_buffer_size is None: - # TODO(josh11b): Auto-tune this value when not specified - shuffle_buffer_size = 10000 - dataset = _maybe_shuffle_and_repeat( - dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) - - # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to - # improve the shape inference, because it makes the batch dimension static. 
- # It is safe to do this because in that case we are repeating the input - # indefinitely, and all batches will be full-sized. - drop_final_batch = drop_final_batch or num_epochs is None - - if parser_fn is None: - dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch) - else: - # TODO(josh11b): if num_parallel_parser_calls is None, use some function - # of num cores instead of map_and_batch's default behavior of one batch. - dataset = dataset.apply(batching.map_and_batch( - parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls, - drop_remainder=drop_final_batch)) - - if prefetch_buffer_size == 0: - return dataset - else: - return dataset.prefetch(buffer_size=prefetch_buffer_size) - +@deprecation.deprecated(None, + "Use `tf.data.experimental.make_csv_dataset(...)`.") def make_csv_dataset( file_pattern, batch_size, @@ -387,7 +112,6 @@ def make_csv_dataset( prefetch_buffer_size: An int specifying the number of feature batches to prefetch for performance improvement. Recommended value is the number of batches consumed per training step. Defaults to auto-tune. - num_parallel_reads: Number of threads used to read CSV records from files. If >1, the results will be interleaved. sloppy: If `True`, reading performance will be improved at @@ -411,106 +135,18 @@ def make_csv_dataset( Raises: ValueError: If any of the arguments is malformed. 
""" - # Create dataset of all matching filenames - filenames = _get_file_names(file_pattern, False) - dataset = dataset_ops.Dataset.from_tensor_slices(filenames) - if shuffle: - dataset = dataset.shuffle(len(filenames), shuffle_seed) - - # Clean arguments; figure out column names and defaults + return readers.make_csv_dataset( + file_pattern, batch_size, column_names, column_defaults, label_name, + select_columns, field_delim, use_quote_delim, na_value, header, + num_epochs, shuffle, shuffle_buffer_size, shuffle_seed, + prefetch_buffer_size, num_parallel_reads, sloppy, num_rows_for_inference, + compression_type) - if column_names is None: - if not header: - raise ValueError("Cannot infer column names without a header line.") - # If column names are not provided, infer from the header lines - column_names = _infer_column_names(filenames, field_delim, use_quote_delim) - if len(column_names) != len(set(column_names)): - raise ValueError("Cannot have duplicate column names.") - if select_columns is not None: - select_columns = _get_sorted_col_indices(select_columns, column_names) - - if column_defaults is not None: - column_defaults = [ - constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x - for x in column_defaults - ] - else: - # If column defaults are not provided, infer from records at graph - # construction time - column_defaults = _infer_column_defaults( - filenames, len(column_names), field_delim, use_quote_delim, na_value, - header, num_rows_for_inference, select_columns) - - if select_columns is not None and len(column_defaults) != len(select_columns): - raise ValueError( - "If specified, column_defaults and select_columns must have same " - "length." 
- ) - if select_columns is not None and len(column_names) > len(select_columns): - # Pick the relevant subset of column names - column_names = [column_names[i] for i in select_columns] - - if label_name is not None and label_name not in column_names: - raise ValueError("`label_name` provided must be one of the columns.") - - def filename_to_dataset(filename): - return CsvDataset( - filename, - record_defaults=column_defaults, - field_delim=field_delim, - use_quote_delim=use_quote_delim, - na_value=na_value, - select_cols=select_columns, - header=header, - compression_type=compression_type, - ) - - def map_fn(*columns): - """Organizes columns into a features dictionary. - - Args: - *columns: list of `Tensor`s corresponding to one csv record. - Returns: - An OrderedDict of feature names to values for that particular record. If - label_name is provided, extracts the label feature to be returned as the - second element of the tuple. - """ - features = collections.OrderedDict(zip(column_names, columns)) - if label_name is not None: - label = features.pop(label_name) - return features, label - return features - - # Read files sequentially (if num_parallel_reads=1) or in parallel - dataset = dataset.apply( - interleave_ops.parallel_interleave( - filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy)) - - dataset = _maybe_shuffle_and_repeat( - dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) - - # Apply batch before map for perf, because map has high overhead relative - # to the size of the computation in each map. - # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to - # improve the shape inference, because it makes the batch dimension static. - # It is safe to do this because in that case we are repeating the input - # indefinitely, and all batches will be full-sized. 
- dataset = dataset.batch(batch_size=batch_size, - drop_remainder=num_epochs is None) - dataset = dataset_ops.MapDataset( - dataset, map_fn, use_inter_op_parallelism=False) - dataset = dataset.prefetch(prefetch_buffer_size) - - return dataset - - -_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB - - -class CsvDataset(dataset_ops.DatasetSource): +class CsvDataset(readers.CsvDataset): """A Dataset comprising lines from one or more CSV files.""" + @deprecation.deprecated(None, "Use `tf.data.experimental.CsvDataset(...)`.") def __init__(self, filenames, record_defaults, @@ -521,140 +157,13 @@ class CsvDataset(dataset_ops.DatasetSource): use_quote_delim=True, na_value="", select_cols=None): - """Creates a `CsvDataset` by reading and decoding CSV files. - - The elements of this dataset correspond to records from the file(s). - RFC 4180 format is expected for CSV files - (https://tools.ietf.org/html/rfc4180) - Note that we allow leading and trailing spaces with int or float field. - - - For example, suppose we have a file 'my_file0.csv' with four CSV columns of - different data types: - ``` - abcdefg,4.28E10,5.55E6,12 - hijklmn,-5.3E14,,2 - ``` - - We can construct a CsvDataset from it as follows: - ```python - dataset = tf.contrib.data.CsvDataset( - "my_file*.csv", - [tf.float32, # Required field, use dtype or empty tensor - tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0 - tf.int32, # Required field, use dtype or empty tensor - ], - select_cols=[1,2,3] # Only parse last three columns - ) - ``` - - The expected output of its iterations is: - ```python - next_element = dataset.make_one_shot_iterator().get_next() - with tf.Session() as sess: - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break - - >> (4.28e10, 5.55e6, 12) - >> (-5.3e14, 0.0, 2) - ``` - - Args: - filenames: A `tf.string` tensor containing one or more filenames. - record_defaults: A list of default values for the CSV fields. 
Each item in - the list is either a valid CSV `DType` (float32, float64, int32, int64, - string), or a `Tensor` object with one of the above types. One per - column of CSV data, with either a scalar `Tensor` default value for the - column if it is optional, or `DType` or empty `Tensor` if required. If - both this and `select_columns` are specified, these must have the same - lengths, and `column_defaults` is assumed to be sorted in order of - increasing column index. - compression_type: (Optional.) A `tf.string` scalar evaluating to one of - `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no - compression. - buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes - to buffer while reading files. Defaults to 4MB. - header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s) - have header line(s) that should be skipped when parsing. Defaults to - `False`. - field_delim: (Optional.) A `tf.string` scalar containing the delimiter - character that separates fields in a record. Defaults to `","`. - use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats - double quotation marks as regular characters inside of string fields - (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`. - na_value: (Optional.) A `tf.string` scalar indicating a value that will - be treated as NA/NaN. - select_cols: (Optional.) A sorted list of column indices to select from - the input data. If specified, only this subset of columns will be - parsed. Defaults to parsing all columns. 
- """ - super(CsvDataset, self).__init__() - self._filenames = ops.convert_to_tensor( - filenames, dtype=dtypes.string, name="filenames") - self._compression_type = convert.optional_param_to_tensor( - "compression_type", - compression_type, - argument_default="", - argument_dtype=dtypes.string) - record_defaults = [ - constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x - for x in record_defaults - ] - self._record_defaults = ops.convert_n_to_tensor( - record_defaults, name="record_defaults") - self._buffer_size = convert.optional_param_to_tensor( - "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES) - self._header = ops.convert_to_tensor( - header, dtype=dtypes.bool, name="header") - self._field_delim = ops.convert_to_tensor( - field_delim, dtype=dtypes.string, name="field_delim") - self._use_quote_delim = ops.convert_to_tensor( - use_quote_delim, dtype=dtypes.bool, name="use_quote_delim") - self._na_value = ops.convert_to_tensor( - na_value, dtype=dtypes.string, name="na_value") - self._select_cols = convert.optional_param_to_tensor( - "select_cols", - select_cols, - argument_default=[], - argument_dtype=dtypes.int64, - ) - self._output_shapes = tuple( - tensor_shape.scalar() for _ in range(len(record_defaults))) - self._output_types = tuple(d.dtype for d in self._record_defaults) - self._output_classes = tuple( - ops.Tensor for _ in range(len(record_defaults))) - - def _as_variant_tensor(self): - # Constructs graph node for the dataset op. 
- return gen_experimental_dataset_ops.experimental_csv_dataset( - filenames=self._filenames, - record_defaults=self._record_defaults, - buffer_size=self._buffer_size, - header=self._header, - output_shapes=self._output_shapes, - field_delim=self._field_delim, - use_quote_delim=self._use_quote_delim, - na_value=self._na_value, - select_cols=self._select_cols, - compression_type=self._compression_type, - ) - - @property - def output_types(self): - return self._output_types - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_classes(self): - return self._output_classes + super(CsvDataset, self).__init__( + filenames, record_defaults, compression_type, buffer_size, header, + field_delim, use_quote_delim, na_value, select_cols) +@deprecation.deprecated( + None, "Use `tf.data.experimental.make_batched_features_dataset(...)`.") def make_batched_features_dataset(file_pattern, batch_size, features, @@ -759,57 +268,15 @@ def make_batched_features_dataset(file_pattern, Raises: ValueError: If `label_key` is not one of the `features` keys. """ - # Create dataset of all matching filenames - filenames = _get_file_names(file_pattern, False) - dataset = dataset_ops.Dataset.from_tensor_slices(filenames) - if shuffle: - dataset = dataset.shuffle(len(filenames), shuffle_seed) - - # Read `Example` records from files as tensor objects. 
- if reader_args is None: - reader_args = [] + return readers.make_batched_features_dataset( + file_pattern, batch_size, features, reader, label_key, reader_args, + num_epochs, shuffle, shuffle_buffer_size, shuffle_seed, + prefetch_buffer_size, reader_num_threads, parser_num_threads, + sloppy_ordering, drop_final_batch) - # Read files sequentially (if reader_num_threads=1) or in parallel - dataset = dataset.apply( - interleave_ops.parallel_interleave( - lambda filename: reader(filename, *reader_args), - cycle_length=reader_num_threads, - sloppy=sloppy_ordering)) - # Extract values if the `Example` tensors are stored as key-value tuples. - if dataset.output_types == (dtypes.string, dtypes.string): - dataset = dataset_ops.MapDataset( - dataset, lambda _, v: v, use_inter_op_parallelism=False) - - # Apply dataset repeat and shuffle transformations. - dataset = _maybe_shuffle_and_repeat( - dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) - - # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to - # improve the shape inference, because it makes the batch dimension static. - # It is safe to do this because in that case we are repeating the input - # indefinitely, and all batches will be full-sized. - dataset = dataset.batch( - batch_size, drop_remainder=drop_final_batch or num_epochs is None) - - # Parse `Example` tensors to a dictionary of `Feature` tensors. - dataset = dataset.apply( - parsing_ops.parse_example_dataset( - features, num_parallel_calls=parser_num_threads)) - - if label_key: - if label_key not in features: - raise ValueError( - "The `label_key` provided (%r) must be one of the `features` keys." 
% - label_key) - dataset = dataset.map(lambda x: (x, x.pop(label_key))) - - dataset = dataset.prefetch(prefetch_buffer_size) - return dataset - - -@deprecation.deprecated(None, - "Use `tf.contrib.data.make_batched_features_dataset`") +@deprecation.deprecated( + None, "Use `tf.data.experimental.make_batched_features_dataset(...)`") def read_batch_features(file_pattern, batch_size, features, @@ -879,7 +346,7 @@ def read_batch_features(file_pattern, Returns: A dict from keys in features to `Tensor` or `SparseTensor` objects. """ - dataset = make_batched_features_dataset( + dataset = readers.make_batched_features_dataset( file_pattern, batch_size, features, @@ -893,96 +360,13 @@ def read_batch_features(file_pattern, return outputs -def _get_file_names(file_pattern, shuffle): - """Parse list of file names from pattern, optionally shuffled. - - Args: - file_pattern: File glob pattern, or list of glob patterns. - shuffle: Whether to shuffle the order of file names. - - Returns: - List of file names matching `file_pattern`. - - Raises: - ValueError: If `file_pattern` is empty, or pattern matches no files. - """ - if isinstance(file_pattern, list): - if not file_pattern: - raise ValueError("File pattern is empty.") - file_names = [] - for entry in file_pattern: - file_names.extend(gfile.Glob(entry)) - else: - file_names = list(gfile.Glob(file_pattern)) - - if not file_names: - raise ValueError("No files match %s." % file_pattern) - - # Sort files so it will be deterministic for unit tests. - if not shuffle: - file_names = sorted(file_names) - return file_names - - -class SqlDataset(dataset_ops.DatasetSource): +class SqlDataset(readers.SqlDataset): """A `Dataset` consisting of the results from a SQL query.""" + @deprecation.deprecated(None, "Use `tf.data.experimental.SqlDataset(...)`.") def __init__(self, driver_name, data_source_name, query, output_types): - """Creates a `SqlDataset`. - - `SqlDataset` allows a user to read data from the result set of a SQL query. 
- For example: - - ```python - dataset = tf.contrib.data.SqlDataset("sqlite", "/foo/bar.sqlite3", - "SELECT name, age FROM people", - (tf.string, tf.int32)) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - # Prints the rows of the result set of the above query. - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break - ``` - - Args: - driver_name: A 0-D `tf.string` tensor containing the database type. - Currently, the only supported value is 'sqlite'. - data_source_name: A 0-D `tf.string` tensor containing a connection string - to connect to the database. - query: A 0-D `tf.string` tensor containing the SQL query to execute. - output_types: A tuple of `tf.DType` objects representing the types of the - columns returned by `query`. - """ - super(SqlDataset, self).__init__() - self._driver_name = ops.convert_to_tensor( - driver_name, dtype=dtypes.string, name="driver_name") - self._data_source_name = ops.convert_to_tensor( - data_source_name, dtype=dtypes.string, name="data_source_name") - self._query = ops.convert_to_tensor( - query, dtype=dtypes.string, name="query") - self._output_types = output_types - - def _as_variant_tensor(self): - return gen_dataset_ops.sql_dataset(self._driver_name, - self._data_source_name, self._query, - nest.flatten(self.output_types), - nest.flatten(self.output_shapes)) - - @property - def output_classes(self): - return nest.map_structure(lambda _: ops.Tensor, self._output_types) - - @property - def output_shapes(self): - return nest.map_structure(lambda _: tensor_shape.TensorShape([]), - self._output_types) - - @property - def output_types(self): - return self._output_types + super(SqlDataset, self).__init__( + driver_name, data_source_name, query, output_types) class LMDBDataset(dataset_ops.DatasetSource): diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index 
75642f143e19c3d77e675384362c4dab94e10932..29d77528d95ba62783c1f7c1c0df530ed3929c9e 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -17,22 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import interleave_ops -from tensorflow.contrib.data.python.ops import scan_ops -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops +from tensorflow.python.data.experimental.ops import resampling +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.rejection_resample(...)`.") def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): """A transformation that resamples a dataset to achieve a target distribution. @@ -52,243 +42,5 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist") - class_values_ds = dataset.map(class_func) - - # Get initial distribution. 
- if initial_dist is not None: - initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") - acceptance_dist, prob_of_original = ( - _calculate_acceptance_probs_with_mixing(initial_dist_t, - target_dist_t)) - initial_dist_ds = dataset_ops.Dataset.from_tensors( - initial_dist_t).repeat() - acceptance_dist_ds = dataset_ops.Dataset.from_tensors( - acceptance_dist).repeat() - prob_of_original_ds = dataset_ops.Dataset.from_tensors( - prob_of_original).repeat() - else: - initial_dist_ds = _estimate_initial_dist_ds( - target_dist_t, class_values_ds) - acceptance_and_original_prob_ds = initial_dist_ds.map( - lambda initial: _calculate_acceptance_probs_with_mixing( - initial, target_dist_t)) - acceptance_dist_ds = acceptance_and_original_prob_ds.map( - lambda accept_prob, _: accept_prob) - prob_of_original_ds = acceptance_and_original_prob_ds.map( - lambda _, prob_original: prob_original) - filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, - class_values_ds, seed) - # Prefetch filtered dataset for speed. - filtered_ds = filtered_ds.prefetch(3) - - prob_original_static = _get_prob_original_static( - initial_dist_t, target_dist_t) if initial_dist is not None else None - if prob_original_static == 1: - return dataset_ops.Dataset.zip((class_values_ds, dataset)) - elif prob_original_static == 0: - return filtered_ds - else: - return interleave_ops.sample_from_datasets( - [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds], - weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]), - seed=seed) - - return _apply_fn - - -def _get_prob_original_static(initial_dist_t, target_dist_t): - """Returns the static probability of sampling from the original. - - `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters - an Op that it isn't defined for. We have some custom logic to avoid this. - - Args: - initial_dist_t: A tensor of the initial distribution. 
- target_dist_t: A tensor of the target distribution. - - Returns: - The probability of sampling from the original distribution as a constant, - if it is a constant, or `None`. - """ - init_static = tensor_util.constant_value(initial_dist_t) - target_static = tensor_util.constant_value(target_dist_t) - - if init_static is None or target_static is None: - return None - else: - return np.min(target_static / init_static) - - -def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds, - seed): - """Filters a dataset based on per-class acceptance probabilities. - - Args: - dataset: The dataset to be filtered. - acceptance_dist_ds: A dataset of acceptance probabilities. - initial_dist_ds: A dataset of the initial probability distribution, given or - estimated. - class_values_ds: A dataset of the corresponding classes. - seed: (Optional.) Python integer seed for the resampler. - - Returns: - A dataset of (class value, data) after filtering. - """ - def maybe_warn_on_large_rejection(accept_dist, initial_dist): - proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist) - return control_flow_ops.cond( - math_ops.less(proportion_rejected, .5), - lambda: accept_dist, - lambda: logging_ops.Print( # pylint: disable=g-long-lambda - accept_dist, [proportion_rejected, initial_dist, accept_dist], - message="Proportion of examples rejected by sampler is high: ", - summarize=100, - first_n=10)) - - acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds, - initial_dist_ds)) - .map(maybe_warn_on_large_rejection)) - - def _gather_and_copy(class_val, acceptance_prob, data): - return class_val, array_ops.gather(acceptance_prob, class_val), data - - current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip( - (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy) - filtered_ds = ( - current_probabilities_and_class_and_data_ds - .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) - return 
filtered_ds.map(lambda class_value, _, data: (class_value, data)) - - -def _estimate_initial_dist_ds( - target_dist_t, class_values_ds, dist_estimation_batch_size=32, - smoothing_constant=10): - num_classes = (target_dist_t.shape[0].value or - array_ops.shape(target_dist_t)[0]) - initial_examples_per_class_seen = array_ops.fill( - [num_classes], np.int64(smoothing_constant)) - - def update_estimate_and_tile(num_examples_per_class_seen, c): - updated_examples_per_class_seen, dist = _estimate_data_distribution( - c, num_examples_per_class_seen) - tiled_dist = array_ops.tile( - array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) - return updated_examples_per_class_seen, tiled_dist - - initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) - .apply(scan_ops.scan(initial_examples_per_class_seen, - update_estimate_and_tile)) - .apply(batching.unbatch())) - - return initial_dist_ds - - -def _get_target_to_initial_ratio(initial_probs, target_probs): - # Add tiny to initial_probs to avoid divide by zero. - denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) - return target_probs / denom - - -def _estimate_data_distribution(c, num_examples_per_class_seen): - """Estimate data distribution as labels are seen. - - Args: - c: The class labels. Type `int32`, shape `[batch_size]`. - num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, - containing counts. - - Returns: - num_examples_per_lass_seen: Updated counts. Type `int64`, shape - `[num_classes]`. - dist: The updated distribution. Type `float32`, shape `[num_classes]`. - """ - num_classes = num_examples_per_class_seen.get_shape()[0].value - # Update the class-count based on what labels are seen in batch. 
- num_examples_per_class_seen = math_ops.add( - num_examples_per_class_seen, math_ops.reduce_sum( - array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) - init_prob_estimate = math_ops.truediv( - num_examples_per_class_seen, - math_ops.reduce_sum(num_examples_per_class_seen)) - dist = math_ops.cast(init_prob_estimate, dtypes.float32) - return num_examples_per_class_seen, dist - - -def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs): - """Calculates the acceptance probabilities and mixing ratio. - - In this case, we assume that we can *either* sample from the original data - distribution with probability `m`, or sample from a reshaped distribution - that comes from rejection sampling on the original distribution. This - rejection sampling is done on a per-class basis, with `a_i` representing the - probability of accepting data from class `i`. - - This method is based on solving the following analysis for the reshaped - distribution: - - Let F be the probability of a rejection (on any example). - Let p_i be the proportion of examples in the data in class i (init_probs) - Let a_i is the rate the rejection sampler should *accept* class i - Let t_i is the target proportion in the minibatches for class i (target_probs) - - ``` - F = sum_i(p_i * (1-a_i)) - = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1 - ``` - - An example with class `i` will be accepted if `k` rejections occur, then an - example with class `i` is seen by the rejector, and it is accepted. 
This can - be written as follows: - - ``` - t_i = sum_k=0^inf(F^k * p_i * a_i) - = p_i * a_j / (1 - F) using geometric series identity, since 0 <= F < 1 - = p_i * a_i / sum_j(p_j * a_j) using F from above - ``` - - Note that the following constraints hold: - ``` - 0 <= p_i <= 1, sum_i(p_i) = 1 - 0 <= a_i <= 1 - 0 <= t_i <= 1, sum_i(t_i) = 1 - ``` - - A solution for a_i in terms of the other variables is the following: - ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` - - If we try to minimize the amount of data rejected, we get the following: - - M_max = max_i [ t_i / p_i ] - M_min = min_i [ t_i / p_i ] - - The desired probability of accepting data if it comes from class `i`: - - a_i = (t_i/p_i - m) / (M_max - m) - - The desired probability of pulling a data element from the original dataset, - rather than the filtered one: - - m = M_min - - Args: - initial_probs: A Tensor of the initial probability distribution, given or - estimated. - target_probs: A Tensor of the corresponding classes. - - Returns: - (A 1D Tensor with the per-class acceptance probabilities, the desired - probability of pull from the original distribution.) - """ - ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs) - max_ratio = math_ops.reduce_max(ratio_l) - min_ratio = math_ops.reduce_min(ratio_l) - - # Target prob to sample from original distribution. - m = min_ratio - - # TODO(joelshor): Simplify fraction, if possible. 
- a_i = (ratio_l - m) / (max_ratio - m) - return a_i, m + return resampling.rejection_resample(class_func, target_dist, initial_dist, + seed) diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index c52582cd35565bbfab9f49760684608656f77223..0ca9fddb23b20995bdcd4d45aa675537111c4552 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -17,137 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections - -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import gen_dataset_ops - - -class _ScanDataset(dataset_ops.UnaryDataset): - """A dataset that scans a function across its input.""" - - def __init__(self, input_dataset, initial_state, scan_func): - """See `scan()` for details.""" - super(_ScanDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - - with ops.name_scope("initial_state"): - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - self._initial_state = nest.pack_sequence_as(initial_state, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor( - t, name="component_%d" % i) - for i, t in enumerate(nest.flatten(initial_state)) - ]) - - # Compute initial values for the state classes, shapes and types based on - # the initial state. The shapes may be refined by running `tf_scan_func` one - # or more times below. 
- self._state_classes = sparse.get_classes(self._initial_state) - self._state_shapes = nest.pack_sequence_as( - self._initial_state, - [t.get_shape() for t in nest.flatten(self._initial_state)]) - self._state_types = nest.pack_sequence_as( - self._initial_state, - [t.dtype for t in nest.flatten(self._initial_state)]) - - # Will be populated by calling `tf_scan_func`. - self._output_classes = None - self._output_shapes = None - self._output_types = None - - # Iteratively rerun the scan function until reaching a fixed point on - # `self._state_shapes`. - need_to_rerun = True - while need_to_rerun: - - wrapped_func = dataset_ops.StructuredFunctionWrapper( - scan_func, "tf.contrib.data.scan()", - input_classes=(self._state_classes, input_dataset.output_classes), - input_shapes=(self._state_shapes, input_dataset.output_shapes), - input_types=(self._state_types, input_dataset.output_types), - add_to_graph=False) - if not ( - isinstance(wrapped_func.output_types, collections.Sequence) and - len(wrapped_func.output_types) == 2): - raise TypeError("The scan function must return a pair comprising the " - "new state and the output value.") - - new_state_classes, self._output_classes = wrapped_func.output_classes - - # Extract and validate class information from the returned values. - for new_state_class, state_class in zip( - nest.flatten(new_state_classes), - nest.flatten(self._state_classes)): - if not issubclass(new_state_class, state_class): - raise TypeError( - "The element classes for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_classes, new_state_classes)) - - # Extract and validate type information from the returned values. - new_state_types, self._output_types = wrapped_func.output_types - for new_state_type, state_type in zip( - nest.flatten(new_state_types), nest.flatten(self._state_types)): - if new_state_type != state_type: - raise TypeError( - "The element types for the new state must match the initial " - "state. 
Expected %s; got %s." % - (self._state_types, new_state_types)) - - # Extract shape information from the returned values. - new_state_shapes, self._output_shapes = wrapped_func.output_shapes - - flat_state_shapes = nest.flatten(self._state_shapes) - flat_new_state_shapes = nest.flatten(new_state_shapes) - weakened_state_shapes = [ - original.most_specific_compatible_shape(new) - for original, new in zip(flat_state_shapes, flat_new_state_shapes) - ] - - need_to_rerun = False - for original_shape, weakened_shape in zip(flat_state_shapes, - weakened_state_shapes): - if original_shape.ndims is not None and ( - weakened_shape.ndims is None or - original_shape.as_list() != weakened_shape.as_list()): - need_to_rerun = True - break - - if need_to_rerun: - self._state_shapes = nest.pack_sequence_as(self._state_shapes, - weakened_state_shapes) - - self._scan_func = wrapped_func.function - self._scan_func.add_to_graph(ops.get_default_graph()) - - def _as_variant_tensor(self): - input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access - return gen_dataset_ops.scan_dataset( - input_t, - nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)), - self._scan_func.captured_inputs, - f=self._scan_func, - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types +from tensorflow.python.data.experimental.ops import scan_ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, "Use `tf.data.experimental.scan(...)`.") def scan(initial_state, scan_func): """A transformation that scans a function across an input dataset. @@ -168,7 +42,4 @@ def scan(initial_state, scan_func): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. 
""" - def _apply_fn(dataset): - return _ScanDataset(dataset, initial_state, scan_func) - - return _apply_fn + return scan_ops.scan(initial_state, scan_func) diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py index 985d1d87d0834a17f0a4f147ab8ace3e2db05f67..329b34fdfecf026688c3ebd210d3400a427940a8 100644 --- a/tensorflow/contrib/data/python/ops/shuffle_ops.py +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -17,54 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import random_seed -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops - - -class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset): - """A `Dataset` that fuses `shuffle` and `repeat`.""" - - def __init__(self, input_dataset, buffer_size, count=None, seed=None): - super(_ShuffleAndRepeatDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - self._buffer_size = ops.convert_to_tensor( - buffer_size, dtype=dtypes.int64, name="buffer_size") - if count is None: - self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") - else: - self._count = ops.convert_to_tensor( - count, dtype=dtypes.int64, name="count") - self._seed, self._seed2 = random_seed.get_seed(seed) - - def _as_variant_tensor(self): - # pylint: disable=protected-access - input_resource = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.shuffle_and_repeat_dataset( - input_resource, - buffer_size=self._buffer_size, - count=self._count, - seed=self._seed, - seed2=self._seed2, - **dataset_ops.flat_structure(self)) - # pylint: enable=protected-access - - @property - def output_classes(self): - return 
self._input_dataset.output_classes - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types +from tensorflow.python.data.experimental.ops import shuffle_ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.shuffle_and_repeat(...)`.") def shuffle_and_repeat(buffer_size, count=None, seed=None): """Shuffles and repeats a Dataset returning a new permutation for each epoch. @@ -93,8 +51,4 @@ def shuffle_and_repeat(buffer_size, count=None, seed=None): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - - def _apply_fn(dataset): # pylint: disable=missing-docstring - return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed) - - return _apply_fn + return shuffle_ops.shuffle_and_repeat(buffer_size, count, seed) diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py index f73c3fd9cba6a19db395bd14e2eef3617158a82a..20cceb4647ae6d5f80a9dbac3baed72d50254f09 100644 --- a/tensorflow/contrib/data/python/ops/threadpool.py +++ b/tensorflow/contrib/data/python/ops/threadpool.py @@ -17,88 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import threading - -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.eager import context -from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops -from tensorflow.python.ops import resource_variable_ops - -_uid_counter = 0 -_uid_lock = threading.Lock() - - -def _generate_shared_name(prefix): - with _uid_lock: - global _uid_counter - uid = _uid_counter - _uid_counter += 1 - return "{}{}".format(prefix, uid) - - -# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable -# or make private / remove. 
-class PrivateThreadPool(object): - """A stateful resource that represents a private thread pool.""" - - def __init__(self, num_threads, display_name=None, - max_intra_op_parallelism=1): - """Creates a `PrivateThreadPool` with the given number of threads.""" - if context.executing_eagerly(): - shared_name = _generate_shared_name("privatethreadpool") - self._resource = ged_ops.experimental_thread_pool_handle( - num_threads=num_threads, - max_intra_op_parallelism=max_intra_op_parallelism, - display_name=display_name, - shared_name=shared_name) - self._resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._resource, handle_device=context.context().device_name) - else: - self._resource = ged_ops.experimental_thread_pool_handle( - num_threads=num_threads, - max_intra_op_parallelism=max_intra_op_parallelism, - display_name=display_name) - - -class _ThreadPoolDataset(dataset_ops.UnaryDataset): - """A `Dataset` that acts as an identity, and sets a custom threadpool.""" - - def __init__(self, input_dataset, thread_pool): - super(_ThreadPoolDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - self._thread_pool = thread_pool - - def _as_variant_tensor(self): - return ged_ops.experimental_thread_pool_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._thread_pool._resource, # pylint: disable=protected-access - **dataset_ops.flat_structure(self)) - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types - - @property - def output_classes(self): - return self._input_dataset.output_classes - - -# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable -# or make private / remove. -def override_threadpool(dataset, thread_pool): - """Returns a new dataset that uses the given thread pool for its operations. - - Args: - dataset: A `tf.data.Dataset` object. 
- thread_pool: A `PrivateThreadPool` object. - - Returns: - A dataset containing the same values as `dataset`, but which uses - `thread_pool` to compute any of its parallel operations (such as - `tf.data.Dataset.map`). - """ - return _ThreadPoolDataset(dataset, thread_pool) +# pylint: disable=unused-import +from tensorflow.python.data.experimental.ops.threadpool import override_threadpool +from tensorflow.python.data.experimental.ops.threadpool import PrivateThreadPool diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py index ed363a7090748cfc0742d3ada18c5527750c2a67..909d06c677ea29733966e0c19a7543b149d2fe74 100644 --- a/tensorflow/contrib/data/python/ops/unique.py +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -17,11 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import gen_experimental_dataset_ops +from tensorflow.python.data.experimental.ops import unique as experimental_unique +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, "Use `tf.data.experimental.unique()`.") def unique(): """Creates a `Dataset` from another `Dataset`, discarding duplicates. @@ -39,39 +39,4 @@ def unique(): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. 
""" - - def _apply_fn(dataset): - return _UniqueDataset(dataset) - - return _apply_fn - - -class _UniqueDataset(dataset_ops.UnaryDataset): - """A `Dataset` contains the unique elements from its input.""" - - def __init__(self, input_dataset): - """See `unique()` for details.""" - super(_UniqueDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - if input_dataset.output_types not in (dtypes.int32, dtypes.int64, - dtypes.string): - raise TypeError( - "`tf.contrib.data.unique()` only supports inputs with a single " - "`tf.int32`, `tf.int64`, or `tf.string` component.") - - def _as_variant_tensor(self): - return gen_experimental_dataset_ops.experimental_unique_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types + return experimental_unique.unique() diff --git a/tensorflow/contrib/data/python/ops/writers.py b/tensorflow/contrib/data/python/ops/writers.py index c455fdcba673853079ff0d162c4799e72bc8e627..42fb69bf077afbd2094f6eb1bf3fe7b17f761910 100644 --- a/tensorflow/contrib/data/python/ops/writers.py +++ b/tensorflow/contrib/data/python/ops/writers.py @@ -17,42 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import convert -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.data.experimental.ops import writers +from tensorflow.python.util import deprecation -class TFRecordWriter(object): +class 
TFRecordWriter(writers.TFRecordWriter): """Writes data to a TFRecord file.""" + @deprecation.deprecated( + None, "Use `tf.data.experimental.TFRecordWriter(...)`.") def __init__(self, filename, compression_type=None): - self._filename = ops.convert_to_tensor( - filename, dtypes.string, name="filename") - self._compression_type = convert.optional_param_to_tensor( - "compression_type", - compression_type, - argument_default="", - argument_dtype=dtypes.string) - - def write(self, dataset): - """Returns a `tf.Operation` to write a dataset to a file. - - Args: - dataset: a `tf.data.Dataset` whose elements are to be written to a file - - Returns: - A `tf.Operation` that, when run, writes contents of `dataset` to a file. - """ - if not isinstance(dataset, dataset_ops.Dataset): - raise TypeError("`dataset` must be a `tf.data.Dataset` object.") - if (dataset.output_types != dtypes.string or - dataset.output_shapes != tensor_shape.scalar()): - raise TypeError( - "`dataset` must produce scalar `DT_STRING` tensors whereas it " - "produces shape {0} and types {1}".format(dataset.output_shapes, - dataset.output_types)) - return gen_dataset_ops.dataset_to_tf_record( - dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access + super(TFRecordWriter, self).__init__(filename, compression_type) diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 2e025765e4aaab7114aa6e3e79336e48a71b5b55..f82453f3b5ea01b8bb64a70bd49f5e3e831bb4e2 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -20,7 +20,7 @@ on many GPUs on one machine. Essentially, we create copies of all variables in the model's layers on each device. We then use all-reduce to combine gradients across the devices before applying them to the variables to keep them in sync. 
* [`CollectiveAllReduceStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/CollectiveAllReduceStrategy): -This is a version of `MirroredStrategy` for multi-working training. It uses +This is a version of `MirroredStrategy` for multi-worker training. It uses a collective op to do all-reduce. This supports between-graph communication and synchronization, and delegates the specifics of the all-reduce implementation to the runtime (as opposed to encoding it in the graph). This allows it to perform @@ -31,8 +31,8 @@ fault-tolerance to allow training to continue when there is worker failure. * [`ParameterServerStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/ParameterServerStrategy): This strategy supports using parameter servers either for multi-GPU local training or asynchronous multi-machine training. When used to train locally, -variables are not mirrored, instead they placed on the CPU and operations are -replicated across all local GPUs. In a multi-machine setting, some are +variables are not mirrored, instead they are placed on the CPU and operations +are replicated across all local GPUs. In a multi-machine setting, some are designated as workers and some as parameter servers. Each variable is placed on one parameter server. Computation operations are replicated across all GPUs of the workers. @@ -190,7 +190,7 @@ in the input function gives a solid boost in performance. When using For multi-worker training, no code change is required to the `Estimator` code. You can run the same model code for all tasks in your cluster including parameter servers and the evaluator. But you need to use -`tf.estimator.train_and_evaluator`, explicitly specify `num_gpus_per_workers` +`tf.estimator.train_and_evaluate`, explicitly specify `num_gpus_per_workers` for your strategy object, and set "TF\_CONFIG" environment variables for each binary running in your cluster. 
We'll provide a Kubernetes template in the [tensorflow/ecosystem](https://github.com/tensorflow/ecosystem) repo which sets diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index cfb9d42a6fc472ca958e883aa0b70b42825d80da..dc2964568b529e4bf73dc648493d03db3cc80c1a 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -22,7 +22,6 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":input_ops", - ":prefetching_ops_v2", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:device_util", @@ -31,6 +30,7 @@ py_library( "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/data/ops:multi_device_iterator_ops", "//tensorflow/python/eager:context", "//tensorflow/python/training/checkpointable:base", "@six_archive//:six", @@ -411,6 +411,24 @@ cuda_py_test( ], ) +cuda_py_test( + name = "moving_averages_test", + srcs = ["moving_averages_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python/eager:test", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:training", + "//tensorflow/python:variables", + ], + tags = [ + "no_pip", + ], +) + cuda_py_test( name = "optimizer_v2_test", srcs = ["optimizer_v2_test.py"], @@ -648,32 +666,6 @@ cuda_py_test( ], ) -py_library( - name = "prefetching_ops_v2", - srcs = ["prefetching_ops_v2.py"], - deps = [ - "//tensorflow/contrib/data/python/ops:prefetching_ops", - "//tensorflow/python:experimental_dataset_ops_gen", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", - ], -) - -cuda_py_test( - name = "prefetching_ops_v2_test", - srcs = ["prefetching_ops_v2_test.py"], - 
additional_deps = [ - ":prefetching_ops_v2", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - ], -) - py_library( name = "input_ops", srcs = ["input_ops.py"], @@ -728,6 +720,7 @@ cuda_py_test( additional_deps = [ ":keras_test_lib", ], + shard_count = 16, tags = [ "multi_and_single_gpu", "no_pip", @@ -736,18 +729,27 @@ cuda_py_test( ], ) -cuda_py_test( - name = "metrics_v1_test", +py_library( + name = "metrics_v1_test_lib", + testonly = 1, srcs = ["metrics_v1_test.py"], - additional_deps = [ + deps = [ ":combinations", - "@absl_py//absl/testing:parameterized", "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:test", + "@absl_py//absl/testing:parameterized", + ], +) + +cuda_py_test( + name = "metrics_v1_test", + srcs = ["metrics_v1_test.py"], + additional_deps = [ + ":metrics_v1_test_lib", ], tags = [ "multi_and_single_gpu", diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 33ffbf6abe1ef0cbd018d96d505482de8c76fd11..6796a23d464d344554ae9654e0992e30df5ad213 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -128,7 +128,8 @@ class CollectiveAllReduceStrategyTestBase( # TODO(yuefengz): support non-Mirrored variable as destinations. 
g = d.reduce( variable_scope.VariableAggregation.SUM, g, destinations=v) - with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + with ops.control_dependencies( + d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) return before_list, after_list diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 82ca041cc210a1334da4ea62985f5c467761ddd4..63a163e76cdd99c73399c657cbe9bc3d010369d2 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -329,10 +329,10 @@ one_device_strategy = NamedDistribution( required_gpus=None) tpu_strategy = NamedDistribution( "TPU", lambda: tpu_lib.TPUStrategy( - TPUClusterResolver(""), steps_per_run=5), + TPUClusterResolver(""), steps_per_run=2), required_tpu=True) tpu_strategy_one_step = NamedDistribution( - "TPU", lambda: tpu_lib.TPUStrategy( + "TPUOneStep", lambda: tpu_lib.TPUStrategy( TPUClusterResolver(""), steps_per_run=1), required_tpu=True) # Note that we disable prefetching for testing since prefetching makes @@ -349,26 +349,26 @@ mirrored_strategy_with_two_gpus = NamedDistribution( required_gpus=2) -adam_optimizer_v1_fn = NamedObject( - "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1)) gradient_descent_optimizer_v1_fn = NamedObject( "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2)) adagrad_optimizer_v1_fn = NamedObject( "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001)) +adam_optimizer_v1_fn = NamedObject("AdamV1", + lambda: adam.AdamOptimizer(0.001, epsilon=1)) rmsprop_optimizer_v1_fn = NamedObject( "RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001)) -optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn, - adagrad_optimizer_v1_fn] -adam_optimizer_v2_fn = NamedObject( - "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1)) +optimizers_v1 = [gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn] 
+ gradient_descent_optimizer_v2_fn = NamedObject( "GradientDescentV2", lambda: gradient_descent_v2.GradientDescentOptimizer(0.2)) adagrad_optimizer_v2_fn = NamedObject( "AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001)) -optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn, - adagrad_optimizer_v2_fn] +adam_optimizer_v2_fn = NamedObject( + "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1)) + +optimizers_v2 = [gradient_descent_optimizer_v2_fn, adagrad_optimizer_v2_fn] graph_and_eager_modes = ["graph", "eager"] diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index a84ef041960e389c08246fc8a16df2300856d968..da7f8c548f94972b6ec0a67848e1520386d1e28b 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -113,7 +113,7 @@ def main(_): distribute=strategy) # Train the model with the train dataset. - model.fit(x=train_ds, epochs=20, steps_per_epoch=310) + model.fit(x=train_ds, epochs=20, steps_per_epoch=468) # Evaluate the model with the eval dataset. 
score = model.evaluate(eval_ds, steps=10, verbose=0) diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 3aab2c521fde480824a52f435801e2ec19b88385..dfa38912897359036d5c333b69f54047f52a2f49 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -189,6 +189,14 @@ def get_dataset(distribution): return dataset +def get_predict_dataset(distribution): + inputs = np.zeros((10, 3), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + dataset = dataset.repeat(100) + dataset = batch_wrapper(dataset, 10, distribution) + return dataset + + strategies = [combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, @@ -347,56 +355,104 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): gfile.DeleteRecursively(self._config.model_dir) -class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): +class TestDistributionStrategyWithNumpyArrays(test.TestCase, + parameterized.TestCase): - def test_validating_dataset_input_tensors_with_shape_mismatch(self): + @combinations.generate(strategy_combinations()) + def test_creating_var_with_numpy_arrays(self, distribution): with self.cached_session(): + x = np.asarray(np.random.random((64, 3)), dtype=np.float32) + var_x = distributed_training_utils.get_var_for_numpy(distribution, x) + val = self.evaluate(var_x.value()) + # Verify that the numpy value is copied to the variable. + self.assertAllEqual(x, val) + + def test_calculating_batch_params(self): + # This verifies that we calculate the number of steps when the batch size + # is specified. + with self.cached_session(): + # 64 is the number of input samples. + inputs = np.zeros((64, 3), dtype=np.float32) + # The number of towers is equal to 3. 
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0', + '/device:GPU:1']) + + with self.assertRaisesRegexp(ValueError, 'Please specify a batch_size ' + 'that is smaller than'): + # The batch size(128) is larger than the number of input + # samples(64). + distributed_training_utils.get_input_batch_params(inputs, + 128, + strategy) + + with self.assertRaisesRegexp(ValueError, 'is smaller than the number ' + 'of towers'): + # The batch size(32) * num_towers(3) is 96 which is greater than the + # number of input samples(64). + distributed_training_utils.get_input_batch_params(inputs, + 32, + strategy) + + # The number of towers now is equal to 2. strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', '/device:CPU:0']) - a = constant_op.constant([1, 2], shape=(1, 2)) - b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) - x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) - y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with strategy.scope(): - # Removed device and input tensor shape details from the error message - # since the order of the device and the corresponding input tensor shape - # is not deterministic over different runs. - with self.assertRaisesRegexp(ValueError, - 'Input tensor shapes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): - distributed_training_utils.validate_distributed_dataset_inputs( - strategy, x, y) - - def test_validating_dataset_input_tensors_with_dtype_mismatch(self): + # 32 is the batch size per tower. + steps = distributed_training_utils.get_input_batch_params(inputs, + 32, + strategy) + # The number of batches is the ratio of input samples(64) to + # batch size(32) which is 2. The number of steps(1) is the ratio of + # number of batches(2) to the number of towers(2). + self.assertEqual(steps, 1) + + # 16 is the batch size per tower. 
+ steps = distributed_training_utils.get_input_batch_params(inputs, + 16, + strategy) + # The number of batches is the ratio of input samples(64) to + # batch size(16) which is 4. The number of steps(2) is the ratio of + # number of batches(4) to the number of towers(2). + self.assertEqual(steps, 2) + + def test_calculating_batch_size(self): with self.cached_session(): + # 64 is the number of input samples. + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', '/device:CPU:0']) - a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) - b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) - x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) - y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with strategy.scope(): - # Removed device and input tensor dtype details from the error message - # since the order of the device and the corresponding input tensor dtype - # is not deterministic over different runs. - with self.assertRaisesRegexp(ValueError, - 'Input tensor dtypes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): - distributed_training_utils.validate_distributed_dataset_inputs( - strategy, x, y) + strategy._require_static_shapes = True - def test_calling_model_with_numpy_arrays(self): + model.compile(optimizer, loss, distribute=strategy) + iterator = model._distribution_standardize_user_data(inputs, + targets, + batch_size=None, + check_steps=True, + steps_name='steps', + steps=3) + + # The global batch size(21) across all towers is the ratio of the input + # samples(64) to the steps(3). + # The batch size(10) per device is the ratio of the global batch size(21) + # to the number of towers(2). + # The global batch size and batch size are rounded integer values. 
+ self.assertEqual(10, distributed_training_utils.get_batch_dimension( + iterator._iterator)) + + @combinations.generate(strategy_combinations()) + def test_calling_model_with_numpy_arrays(self, distribution): with self.cached_session(): model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' - metrics = ['mae', keras.metrics.CategoricalAccuracy()] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) inputs = np.zeros((64, 3), dtype=np.float32) targets = np.zeros((64, 4), dtype=np.float32) @@ -419,6 +475,52 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): # with batch_size model.predict(inputs, batch_size=8) + @combinations.generate(strategy_combinations()) + def test_calling_model_with_nested_numpy_arrays(self, distribution): + with self.cached_session(): + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + input_a_np = np.asarray(np.random.random((64, 3)), dtype=np.float32) + input_b_np = np.asarray(np.random.random((64, 3)), dtype=np.float32) + inputs = [input_a_np, input_b_np] + + output_d_np = np.asarray(np.random.random((64, 4)), dtype=np.float32) + output_e_np = np.asarray(np.random.random((64, 4)), dtype=np.float32) + targets = [output_d_np, output_e_np] + + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=8, verbose=0) + + # TODO(anjalisridhar): We need tests for when the batch size 
and steps are + # smaller and results in a 0 batch_size and steps value. + model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) + + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) + + +class TestDistributionStrategyWithDatasets(test.TestCase, + parameterized.TestCase): + @combinations.generate(strategy_combinations()) def test_calling_model_on_same_dataset(self, distribution): with self.cached_session(): @@ -436,7 +538,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): validation_data=dataset, validation_steps=2) model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, validation_data=dataset, validation_steps=2) - model.predict(dataset, steps=2) + model.predict(get_predict_dataset(distribution), steps=2) # TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work # as clone_model's input_tensors argument only seems to accept list and not @@ -496,10 +598,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) model.evaluate(dataset, steps=2, verbose=1) - model.predict(dataset, steps=2) - # Test with validation data - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, - validation_data=dataset, validation_steps=2) + model.predict(get_predict_dataset(distribution), steps=2) @combinations.generate(strategy_and_optimizer_combinations()) def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer): @@ -513,7 +612,135 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) model.evaluate(dataset, steps=2, verbose=1) - model.predict(dataset, steps=2) + model.predict(get_predict_dataset(distribution), steps=2) + + def test_dataset_input_shape_validation(self): + 
with self.cached_session(): + model = get_model() + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + + model.compile(optimizer, loss, distribute=strategy) + + # User forgets to batch the dataset + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + + with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + # Wrong input shape + inputs = np.zeros((10, 5), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + with self.assertRaisesRegexp(ValueError, + 'expected input to have shape'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + @combinations.generate(combinations.combine( + distribution=[combinations.tpu_strategy_one_step], + mode=['graph'])) + def test_dataset_input_shape_fully_defined(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + dataset = get_dataset(distribution) + # Input shapes are not fully known. Batch dimension is unknown as we are + # not using the drop_remainder argument. + dataset = dataset.repeat(100).batch(10) + + with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + def test_learning_phase_value(self): + # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare + # meaningful values. Currently we don't pass the learning phase if the + # Lambda layer uses the learning phase. 
+ with self.cached_session(): + x = keras.layers.Input(shape=(1,), name='input') + y = keras.layers.Dense(1, kernel_initializer='ones')(x) + z = keras.layers.Dropout(0.9999)(y) + model = keras.Model(x, z) + initial_weights = model.get_weights() + + optimizer = gradient_descent.GradientDescentOptimizer(0.005) + loss = 'mse' + metrics = ['acc'] + strategy = mirrored_strategy.MirroredStrategy( + ['/device:GPU:0', '/device:GPU:1']) + + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.ones((10, 1), dtype=np.float32) + targets = np.ones((10, 1), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat().batch(8) + hist = model.fit(dataset, epochs=1, steps_per_epoch=20, verbose=1) + self.assertAlmostEqual(hist.history['acc'][0], 0, 0) + + model.set_weights(initial_weights) + evaluate_output = model.evaluate(dataset, steps=20) + self.assertAlmostEqual(evaluate_output[1], 1, 0) + + inputs = np.ones((10, 1), dtype=np.float32) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + predict_dataset = predict_dataset.repeat().batch(5) + output = model.predict(predict_dataset, steps=10) + ref_output = np.ones((50, 1), dtype=np.float32) + self.assertArrayNear(output[0], ref_output, 1e-1) + + +class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): + + def test_validating_dataset_input_tensors_with_shape_mismatch(self): + with self.cached_session(): + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + a = constant_op.constant([1, 2], shape=(1, 2)) + b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) + x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) + y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) + with strategy.scope(): + # Removed device and input tensor shape details from the error message + # since the order of the device and the corresponding input tensor 
shape + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor shapes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + distributed_training_utils.validate_distributed_dataset_inputs( + strategy, x, y) + + def test_validating_dataset_input_tensors_with_dtype_mismatch(self): + with self.cached_session(): + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) + b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) + x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) + y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) + with strategy.scope(): + # Removed device and input tensor dtype details from the error message + # since the order of the device and the corresponding input tensor dtype + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor dtypes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + distributed_training_utils.validate_distributed_dataset_inputs( + strategy, x, y) def test_unsupported_features(self): with self.cached_session(): @@ -595,91 +822,8 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) - def test_dataset_input_shape_validation(self): - with self.cached_session(): - model = get_model() - - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - - model.compile(optimizer, loss, distribute=strategy) - - # User forgets to batch the dataset - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, 
targets)) - dataset = dataset.repeat(100) - - with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - - # Wrong input shape - inputs = np.zeros((10, 5), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - dataset = dataset.batch(10) - - with self.assertRaisesRegexp(ValueError, - 'expected input to have shape'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - - @combinations.generate(combinations.combine( - distribution=[combinations.tpu_strategy_one_step], - mode=['graph'])) - def test_dataset_input_shape_fully_defined(self, distribution): - with self.cached_session(): - model = get_model() - - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) - - dataset = get_dataset(distribution) - # Input shapes are not fully known. Batch dimension is unknown as we are - # not using the drop_remainder argument. - dataset = dataset.repeat(100).batch(10) - - with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - def test_learning_phase_value(self): - # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare - # meaningful values. Currently we don't pass the learning phase if the - # Lambda layer uses the learning phase. 
- with self.cached_session(): - x = keras.layers.Input(shape=(16,), name='input') - y = keras.layers.Dense(16)(x) - z = keras.layers.Dropout(0.9999)(y) - model = keras.Model(x, z) - - optimizer = gradient_descent.GradientDescentOptimizer(0.005) - loss = 'mse' - metrics = ['acc'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) - - inputs = np.random.rand(10, 16) - targets = np.ones((10, 16), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - dataset = dataset.batch(8) - - hist = model.fit(dataset, epochs=5, steps_per_epoch=20, verbose=1) - self.assertEqual(hist.history['acc'][0], 1) - - evaluate_output = model.evaluate(dataset, steps=20) - self.assertEqual(evaluate_output[1], 0) - - predict_output = model.predict(dataset, steps=1) - self.assertNotEqual(np.mean(predict_output), 0) - - -class LossMaskingWithDistributionStrategyTest(test.TestCase): +class TestDistributionStrategyWithLossMasking(test.TestCase): # TODO(priyag): Enable all strategies for this test. Currently it does not # work for TPU due to some invalid datatype. 
@@ -706,7 +850,7 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase): self.assertEqual(hist.history['loss'][0], 0) -class NormalizationLayerWithDistributionStrategyTest( +class TestDistributionStrategyWithNormalizationLayer( test.TestCase, parameterized.TestCase): @combinations.generate(strategy_combinations()) @@ -726,16 +870,20 @@ class NormalizationLayerWithDistributionStrategyTest( dataset = dataset.repeat(100) dataset = batch_wrapper(dataset, 32, distribution) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x) + predict_dataset = predict_dataset.repeat(100) + predict_dataset = batch_wrapper(predict_dataset, 32, distribution) + model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) - out = model.predict(dataset, steps=2) + out = model.predict(predict_dataset, steps=2) out -= keras.backend.eval(norm.beta) out /= keras.backend.eval(norm.gamma) np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) -class CorrectnessWithDistributionStrategyTest(test.TestCase, - parameterized.TestCase): +class TestDistributionStrategyCorrectness(test.TestCase, + parameterized.TestCase): @combinations.generate(strategy_combinations()) def test_metric_correctness(self, distribution): @@ -811,8 +959,7 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase, predict_batch_size = 4 if with_distribution: predict_batch_size //= with_distribution.num_towers - predict_dataset = dataset_ops.Dataset.from_tensor_slices((x_predict, - x_predict)) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) predict_dataset = batch_wrapper(predict_dataset, predict_batch_size, distribution) predict_result = model.predict(predict_dataset, steps=1) diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 8163494c8ed2c5c2164df2e731d09ebb794414cd..2c79a8bfd3cf5a70c6a940b19aa3b7268ce7d524 100644 --- 
a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test from tensorflow.python.framework import ops @@ -35,7 +36,8 @@ def _labeled_dataset_fn(): # 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True return dataset_ops.Dataset.range(1000).map( - lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4) + lambda x: {"labels": x % 5, "predictions": x % 3}).batch( + 4, drop_remainder=True) def _boolean_dataset_fn(): @@ -47,7 +49,8 @@ def _boolean_dataset_fn(): # F, T -> FP; T, F -> FN; F, F -> TN return dataset_ops.Dataset.from_tensor_slices({ "labels": [True, False, True, False], - "predictions": [True, True, False, False]}).repeat().batch(3) + "predictions": [True, True, False, False]}).repeat().batch( + 3, drop_remainder=True) def _threshold_dataset_fn(): @@ -59,7 +62,8 @@ def _threshold_dataset_fn(): # False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN return dataset_ops.Dataset.from_tensor_slices({ "labels": [True, False, True, False], - "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3) + "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch( + 3, drop_remainder=True) def _regression_dataset_fn(): @@ -79,6 +83,12 @@ def all_combinations(): mode=["graph"]) +def tpu_combinations(): + return combinations.combine(distribution=[combinations.tpu_strategy_one_step, + combinations.tpu_strategy], + mode=["graph"]) + + # TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k, # metrics.precision_at_k class MetricsV1Test(test.TestCase, parameterized.TestCase): @@ -86,43 +96,52 @@ 
class MetricsV1Test(test.TestCase, parameterized.TestCase): def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): with ops.Graph().as_default(), distribution.scope(): iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() - value, update = distribution.call_for_each_tower( - metric_fn, iterator.get_next()) - update = distribution.group(update) + dataset_fn).make_initializable_iterator() + if isinstance(distribution, tpu_strategy.TPUStrategy): + def step_fn(ctx, inputs): + value, update = distribution.call_for_each_tower( + metric_fn, inputs) + ctx.set_non_tensor_output(name="value", output=value) + return distribution.group(update) + + ctx = distribution.run_steps_on_dataset( + step_fn, iterator, iterations=distribution.steps_per_run) + update = ctx.run_op + value = ctx.non_tensor_outputs["value"] + # In each run, we run multiple steps, and each step consumes as many + # batches as the number of towers. + batches_per_update = ( + distribution.num_towers * distribution.steps_per_run) + else: + value, update = distribution.call_for_each_tower( + metric_fn, iterator.get_next()) + update = distribution.group(update) + # TODO(josh11b): Once we switch to using a global batch size for input, + # replace "distribution.num_towers" with "1". + batches_per_update = distribution.num_towers + + self.evaluate(iterator.initializer) + self.evaluate(distribution.initialize()) self.evaluate(variables.local_variables_initializer()) - # TODO(josh11b): Once we switch to using a global batch size for input, - # replace "distribution.num_towers" with "1". - batches_per_update = distribution.num_towers - - # Update variables using the first `num_towers` batches. - self.evaluate(update) - self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value), - 0.001, msg="After first update") - - # Update variables using the second `num_towers` batches.
- self.evaluate(update) - self.assertAllClose(expected_fn(2 * batches_per_update), - self.evaluate(value), - 0.001, - msg="After second update") - - if batches_per_update == 1: # Consume 4 input batches - self.evaluate(update) - self.assertAllClose(expected_fn(3 * batches_per_update), - self.evaluate(value), - 0.001, - msg="After third update") + + batches_consumed = 0 + for i in range(4): self.evaluate(update) - self.assertAllClose(expected_fn(4 * batches_per_update), + batches_consumed += batches_per_update + self.assertAllClose(expected_fn(batches_consumed), self.evaluate(value), 0.001, - msg="After fourth update") + msg="After update #" + str(i+1)) + if batches_consumed >= 4: # Consume 4 input batches in total. + break - @combinations.generate(all_combinations()) + self.evaluate(distribution.finalize()) + + @combinations.generate(all_combinations() + tpu_combinations()) def testMean(self, distribution): def _dataset_fn(): - return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4) + return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch( + 4, drop_remainder=True) def _expected_fn(num_batches): # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc. @@ -130,7 +149,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testAccuracy(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -143,6 +162,8 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + # TODO(priyag, jhseu): Enable TPU for this test once scatter_add is added + # for TPUMirroredVariable. 
@combinations.generate(all_combinations()) def testMeanPerClassAccuracy(self, distribution): def _metric_fn(x): @@ -161,6 +182,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + # NOTE(priyag): This metric doesn't work on TPUs yet. @combinations.generate(all_combinations()) def testMeanIOU(self, distribution): def _metric_fn(x): @@ -179,7 +201,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testMeanTensor(self, distribution): def _dataset_fn(): dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float) @@ -198,7 +220,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _dataset_fn, metrics.mean_tensor, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testAUCROC(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -212,7 +234,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testAUCPR(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -226,7 +248,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testFalseNegatives(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -239,7 +261,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, 
_boolean_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testFalseNegativesAtThresholds(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -252,7 +274,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testTrueNegatives(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -265,7 +287,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testTrueNegativesAtThresholds(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -278,7 +300,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testFalsePositives(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -291,7 +313,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testFalsePositivesAtThresholds(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -304,7 +326,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def 
testTruePositives(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -317,7 +339,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testTruePositivesAtThresholds(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -330,7 +352,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testPrecision(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -343,7 +365,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testPrecisionAtThreshold(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -356,7 +378,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testRecall(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -369,7 +391,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testRecallAtThreshold(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -382,7 +404,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, 
_threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testMeanSquaredError(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -395,7 +417,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _regression_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testRootMeanSquaredError(self, distribution): def _metric_fn(x): labels = x["labels"] diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index ba147e78241e5ab45809e498e00debd45a2c49b4..3c4544a39ef85c18a34216ba7f3ac65b45216003 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -41,6 +41,14 @@ from tensorflow.python.ops.losses import losses_impl class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): + def _get_iterator(self, ds): + if context.executing_eagerly(): + iterator = ds.make_one_shot_iterator() + else: + iterator = ds.make_initializable_iterator() + self.evaluate(iterator.initializer) + return iterator + @combinations.generate( combinations.times( combinations.distributions_and_v1_optimizers(), @@ -62,8 +70,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution.call_for_each_tower( model_fn, *inputs, run_concurrently=layer.built)) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) def run_step(): return distribution.run_steps_on_dataset( @@ -99,8 +106,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, 
use_callable_loss=use_callable_loss) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) def run_step(): return distribution.group( @@ -159,8 +165,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution.call_for_each_tower( model_fn, *inputs, run_concurrently=layer.built)) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) def run_step(): return distribution.run_steps_on_dataset( @@ -179,11 +184,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def get_expected_variables(optimizer_fn, num_parameter_devices): variables_map = { "GradientDescent": ["dense/kernel", "dense/bias"], - "Adam": [ - "dense/kernel", "dense/bias", "beta1_power", "beta2_power", - "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam", - "dense/bias/Adam_1" - ], "Adagrad": [ "dense/kernel/Adagrad", "dense/kernel", "dense/bias/Adagrad", "dense/bias" @@ -244,8 +244,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) return control_flow_ops.group(fetches) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) def run_step(): return distribution.run_steps_on_dataset( @@ -338,8 +337,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution.call_for_each_tower( model_fn, x, y, run_concurrently=False)) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) def run_step(): return distribution.run_steps_on_dataset( @@ -432,8 +430,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): output=loss) return 
distribution.group(train_op) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) def run_step(): initial_loss = lambda: constant_op.constant(1e7) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 4d7516063cafcd552399da1489a146592081be8f..0f82508428a58fb671cef25c97ca5880ebb38e83 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -293,7 +293,8 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) for v in index.values(): - l.remove(v) + if v in l: + l.remove(v) g.add_to_collections(collections, result) elif ops.GraphKeys.GLOBAL_STEP in collections: ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result) @@ -318,12 +319,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed). The distribution strategy inherits these concepts as well and in addition to that we also clarify several more concepts: - * **In-graph replication**: the `client` creates a single `tf.Graph` that + + * **In-graph replication**: the `client` creates a single `tf.Graph` that specifies tasks for devices on all workers. The `client` then creates a client session which will talk to the `master` service of a `worker`. Then the `master` will partition the graph and distribute the work to all participating workers. - * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one + * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one physical machine. We will have multiple `worker`s with different `task` index. 
They all do similar things except for one worker checkpointing model variables, writing summaries, etc. in addition to its ordinary work. @@ -460,16 +462,20 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # name as the absolute name of the variable. kwargs["name"] = "%s/replica_%d/" % (var0name, i) # Initialize replicas with the same value: - if context.executing_eagerly(): - kwargs["initial_value"] = array_ops.identity( - index[devices[0]].value()) - else: - def initial_value_fn(device=d): + def initial_value_fn(device=d): + if context.executing_eagerly(): + init_value = index[devices[0]].value() + return array_ops.identity(init_value) + else: with ops.device(device): - return array_ops.identity(index[devices[0]].initial_value) - kwargs["initial_value"] = initial_value_fn + init_value = index[devices[0]].initial_value + return array_ops.identity(init_value) + kwargs["initial_value"] = initial_value_fn with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - v = next_creator(*args, **kwargs) + # Don't record operations (e.g. other variable reads) during + # variable creation. + with tape.stop_recording(): + v = next_creator(*args, **kwargs) assert not isinstance(v, values.DistributedVariable) index[d] = v return index @@ -627,9 +633,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): return self._get_cross_tower_ops().batch_reduce(aggregation, value_destination_pairs) - def _update(self, var, fn, *args, **kwargs): + def _update(self, var, options, fn, *args, **kwargs): # TODO(josh11b): In eager mode, use one thread per device. assert isinstance(var, values.DistributedVariable) + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. 
updates = {} for d, v in var._index.items(): # pylint: disable=protected-access name = "update_%d" % self._device_index.get(d) @@ -638,10 +646,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): updates[d] = fn(v, *values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) - return values.regroup(updates, values.Mirrored) + return values.update_regroup(self, updates, should_group) - def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): assert isinstance(colocate_with, list) + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. # TODO(josh11b): In eager mode, use one thread per device. updates = {} for d in colocate_with: @@ -649,7 +659,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): updates[d] = fn(*values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) - return values.regroup(updates, values.Mirrored) + return values.update_regroup(self, updates, should_group) def read_var(self, tower_local_var): """Read the aggregate value of a tower-local variable.""" diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index f51e543624d23e378c3a44cb8bce956c71b6e40d..fd833c772d49ed0fad6745a66065c155f4570395 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -20,6 +20,8 @@ from __future__ import print_function import sys +import numpy as np + from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import strategy_test_lib @@ -34,7 +36,10 
@@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training as keras_training +from tensorflow.python.keras.layers import core as keras_core from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell_impl @@ -43,6 +48,8 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import device_util from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import server_lib @@ -300,9 +307,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase): dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) - features = dist.distribute_dataset( - lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) - ).make_one_shot_iterator().get_next() + ds = dist.distribute_dataset( + lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) + if context.executing_eagerly(): + iterator = ds.make_one_shot_iterator() + else: + iterator = ds.make_initializable_iterator() + self.evaluate([iterator.initializer]) + + features = iterator.get_next() with dist.scope(): result = dist.call_for_each_tower( @@ -826,7 +839,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with dist.scope(): ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False) - update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0)) + update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False) # Initialize variables. 
self.evaluate(variables.global_variables_initializer()) @@ -1245,6 +1258,22 @@ class MockModel(object): return x +class MiniModel(keras_training.Model): + """Minimal model for mnist. + + Useful for testing and debugging on slow TPU simulators. + """ + + def __init__(self): + super(MiniModel, self).__init__(name="") + self.fc = keras_core.Dense(1, name="fc", kernel_initializer="ones", + bias_initializer="ones") + + def call(self, inputs, training=True): + inputs = array_ops.ones([1, 10]) + return self.fc(inputs) + + class MirroredStrategyDefunTest(test.TestCase): def _skip_eager_if_gpus_less_than(self, num_gpus): @@ -1365,6 +1394,41 @@ class MirroredStrategyDefunTest(test.TestCase): "GPU:0": 3.0 * 1.25}) self._call_and_check(fn1, [factors], expected_result, [fn1]) + @test_util.run_in_graph_and_eager_modes() + def testTrain(self): + self._skip_eager_if_gpus_less_than(1) + + cpu_dev = device_util.canonicalize("CPU:0") + gpu_dev = device_util.canonicalize("GPU:0") + devices = [cpu_dev, gpu_dev] + dist = mirrored_strategy.MirroredStrategy(devices) + + with dist.scope(): + mock_model = MiniModel() + mock_model.call = function.defun(mock_model.call) + + def loss_fn(ctx): + del ctx + return mock_model(array_ops.ones([1, 10])) + + gradients_fn = backprop.implicit_grad(loss_fn) + gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) + grads_and_vars = dist.call_for_each_tower( + gradients_fn, None, run_concurrently=False) + + optimizer = gradient_descent.GradientDescentOptimizer(0.25) + update_ops = optimizer._distributed_apply(dist, grads_and_vars) # pylint: disable=protected-access + + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + self.evaluate(update_ops) + + updated_var_values = self.evaluate(mock_model.variables) + # All variables start at 1.0 and get two updates of 0.25. 
+ self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0]) + self.assertAllEqual([0.5], updated_var_values[1]) + + class MultiWorkerMirroredStrategyTest( multi_worker_test_base.MultiWorkerTestBase, diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py index 7644acedc99361d7287a91832d76bc68cbc6ac0a..17b7ab74f63f42e1ee14a82d3bffdd1df9b25857 100644 --- a/tensorflow/contrib/distribute/python/monitor.py +++ b/tensorflow/contrib/distribute/python/monitor.py @@ -51,6 +51,7 @@ class Monitor(object): else: if session is None: raise ValueError("Should provide a `session` in Graph mode.") + session.run(step_callable._iterator.initializer) # pylint: disable=protected-access self._run_step = session.make_callable(step_callable()) session.run(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py new file mode 100644 index 0000000000000000000000000000000000000000..119352ad9195dc51201863f34aef19cb3289e635 --- /dev/null +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -0,0 +1,141 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for training.moving_averages when using a DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.training import moving_averages + + +all_combinations = combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu], + mode=["graph"]) + + +class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): + + @combinations.generate(all_combinations) + def testTowerModeWithoutZeroDebias(self, distribution): + tower_id = [0] + + def tower_fn(): + var = variables.Variable([10.0, 11.0]) + val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]]) + tower_id[0] += 1 + decay = 0.25 + assign = moving_averages.assign_moving_average( + var, val, decay, zero_debias=False) + return var, assign + + with distribution.scope(), self.cached_session() as sess: + var, assign = distribution.call_for_each_tower(tower_fn) + variables.global_variables_initializer().run() + self.assertAllClose([10.0, 11.0], var.eval()) + sess.run(distribution.unwrap(assign)) + # Mean of val across calls to tower_fn(). 
+ average_val = [1.0 + 0.5 * (tower_id[0] - 1), + 2.0 - 0.5 * (tower_id[0] - 1)] + val_weight = 1.0 - 0.25 + self.assertAllClose( + [10.0 * 0.25 + average_val[0] * val_weight, + 11.0 * 0.25 + average_val[1] * val_weight], + var.eval()) + + @combinations.generate(all_combinations) + def testTowerMode(self, distribution): + tower_id = [0] + + def tower_fn(): + var = variables.Variable([0.0, 0.0]) + val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]]) + tower_id[0] += 1 + decay = 0.25 + assign = moving_averages.assign_moving_average(var, val, decay) + return var, assign.op + + with distribution.scope(), self.cached_session() as sess: + var, assign_op = distribution.call_for_each_tower(tower_fn) + variables.global_variables_initializer().run() + self.assertAllClose([0.0, 0.0], var.eval()) + sess.run(distribution.unwrap(assign_op)) + # Mean of val across calls to tower_fn(). + average_val = [1.0 + 0.5 * (tower_id[0] - 1), + 2.0 - 0.5 * (tower_id[0] - 1)] + self.assertAllClose(average_val, var.eval()) + + @combinations.generate(all_combinations) + def testCrossTowerWithoutZeroDebias(self, distribution): + with distribution.scope(), self.cached_session() as sess: + var = variables.Variable([10.0, 11.0]) + val = constant_op.constant([1.0, 2.0]) + decay = 0.25 + # NOTE(josh11b): We currently generate an error if val is a PerDevice value. + assign = moving_averages.assign_moving_average( + var, val, decay, zero_debias=False) + + variables.global_variables_initializer().run() + self.assertAllClose([10.0, 11.0], var.eval()) + sess.run(assign) + average_val = [1.0, 2.0] + val_weight = 1.0 - 0.25 + self.assertAllClose( + [10.0 * 0.25 + average_val[0] * val_weight, + 11.0 * 0.25 + average_val[1] * val_weight], + var.eval()) + # Also try assign.op. 
+ sess.run(assign.op) + orig_weight = 0.25 * 0.25 + val_weight = 1.0 - orig_weight + self.assertAllClose( + [10.0 * orig_weight + average_val[0] * val_weight, + 11.0 * orig_weight + average_val[1] * val_weight], + var.eval()) + + @combinations.generate(all_combinations) + def testCrossTower(self, distribution): + with distribution.scope(), self.cached_session() as sess: + var = variables.Variable([0.0, 0.0]) + val = array_ops.placeholder(dtypes.float32) + decay = 0.25 + # NOTE(josh11b): We currently generate an error if val is a PerDevice value. + assign = moving_averages.assign_moving_average(var, val, decay) + + variables.global_variables_initializer().run() + self.assertAllClose([0.0, 0.0], var.eval()) + sess.run(assign, feed_dict={val: [1.0, 2.0]}) + self.assertAllClose([1.0, 2.0], var.eval()) + + # Also try assign.op. + sess.run(assign.op, feed_dict={val: [10.0, 0.0]}) + self.assertAllClose( + [(1.0 * 0.25 + 10.0) / (1.0 * 0.25 + 1.0), + (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)], + var.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 23b220f64b843a83aba3f9867b61415b70f19668..f5259190485e701c190beb49220caff743f8fdcb 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -141,14 +141,21 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): else: assert False - def _update(self, var, fn, *args, **kwargs): - with ops.device(self._device), distribute_lib.UpdateContext(self._device): - return fn(var, *args, **kwargs) + def _update(self, var, options, fn, *args, **kwargs): + # The implementations of _update() and _update_non_slot() are identical + # except _update() passes `var` as the first argument to `fn()`. 
+ return self._update_non_slot(var, options, fn, var, *args, **kwargs) - def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): del colocate_with + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. with ops.device(self._device), distribute_lib.UpdateContext(self._device): - return fn(*args, **kwargs) + result = fn(*args, **kwargs) + if should_group: + return result + else: + return nest.map_structure(self._unwrap, result) def read_var(self, tower_local_var): """Read the aggregate value of a tower-local variable.""" diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index 6e9ba37a198fc8038c086d2672251adfac30fdcf..3064433129865703f574ba79e843880fc4390e74 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -42,8 +42,11 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + ds = distribution.distribute_dataset(dataset_fn) + if context.executing_eagerly(): + iterator = ds.make_one_shot_iterator() + else: + iterator = ds.make_initializable_iterator() def run_step(): return control_flow_ops.group(distribution.unwrap( @@ -52,6 +55,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): if not context.executing_eagerly(): with self.cached_session() as sess: + sess.run(iterator.initializer) run_step = sess.make_callable(run_step()) self.evaluate(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py 
b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 1125d027f64420863386d4fbd9db5564a5847825..6ddd91507bf86e8b0cf710a2340fd61abcdebe71 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -343,21 +343,33 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): return nest.map_structure(_select_fn, structured) - def _update(self, var, fn, *args, **kwargs): + def _update(self, var, options, fn, *args, **kwargs): if isinstance(var, values.AggregatingVariable): var = var.get() if not isinstance(var, resource_variable_ops.ResourceVariable): raise ValueError( "You can not update `var` %r. It must be a Variable." % var) + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): - return fn(var, *self._select_single_value(args), - **self._select_single_value(kwargs)) + result = fn(var, *self._select_single_value(args), + **self._select_single_value(kwargs)) + if should_group: + return result + else: + return nest.map_structure(self._unwrap, result) # TODO(yuefengz): does it need to call _select_single_value? - def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. 
with ops.device( colocate_with.device), distribute_lib.UpdateContext(colocate_with): - return fn(*args, **kwargs) + result = fn(*args, **kwargs) + if should_group: + return result + else: + return nest.map_structure(self._unwrap, result) def _unwrap(self, val): if isinstance(val, values.DistributedValues): diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 12789e0bc9f1c89ef8d57c40a978e2bb9471997b..9c112e4f851b5e5e6f65c0bd9d9564420f8d4446 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -262,7 +262,9 @@ class ParameterServerStrategyTestBase( h = f + 1.0 self.assertEqual( device_util.canonicalize(u.device), tower_variable_device) - self.assertEqual(device_util.canonicalize(x.device), h.device) + self.assertEqual( + device_util.canonicalize(x.device), + device_util.canonicalize(h.device)) return y_add, z_add, f y, z, f = d.call_for_each_tower(model_fn) @@ -395,7 +397,8 @@ class ParameterServerStrategyTestBase( # TODO(yuefengz): support non-Mirrored variable as destinations. g = d.reduce( variable_scope.VariableAggregation.SUM, g, destinations=v) - with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + with ops.control_dependencies( + d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) return before_list, after_list diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py deleted file mode 100644 index 8d949943b778df5f60c5f2b5107c28840440e85e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Extension of prefetching_ops to support more than one device.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import warnings - -from tensorflow.contrib.data.python.ops import prefetching_ops -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.data.util import nest as data_nest -from tensorflow.python.data.util import sparse -from tensorflow.python.eager import context -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops -from tensorflow.python.util import nest - - -# pylint: disable=protected-access -class _PrefetchToDeviceIterator(object): - """A replacement for `tf.data.Iterator` that prefetches to another device. - - Args: - input_dataset: The input dataset. - one_shot: If true, we make a one shot iterator that's already initialized. - devices: Devices on which to prefetch. - buffer_size: Size of the prefetching buffer. - shared_name: (Optional.) If non-empty, the returned iterator will be shared - under the given name across multiple sessions that share the same devices - (e.g. when using a remote server). 
Only used if one_shot is False. - - Returns: - An Iterator type object. - """ - - def __init__(self, - input_dataset, - one_shot, - devices, - buffer_size, - shared_name=None): - self._input_dataset = input_dataset - self._get_next_call_count = 0 - self._one_shot = one_shot - if shared_name is None: - shared_name = "" - self._devices = devices - - if self._one_shot: - self._input_iterator = input_dataset.make_one_shot_iterator() - else: - self._input_iterator = iterator_ops.Iterator.from_structure( - self._input_dataset.output_types, self._input_dataset.output_shapes, - shared_name, self._input_dataset.output_classes) - input_iterator_handle = self._input_iterator.string_handle() - - @function.Defun(dtypes.string) - def _prefetch_fn(handle): - """Prefetches one element from `input_iterator`.""" - remote_iterator = iterator_ops.Iterator.from_string_handle( - handle, self._input_iterator.output_types, - self._input_iterator.output_shapes, - self._input_iterator.output_classes) - ret = remote_iterator.get_next() - return nest.flatten(sparse.serialize_sparse_tensors(ret)) - - target_device = ged_ops.experimental_iterator_get_device( - self._input_iterator._iterator_resource) - self._buffering_resources = [] - for device in nest.flatten(self._devices): - with ops.device(device): - buffer_resource_handle = prefetching_ops.function_buffering_resource( - f=_prefetch_fn, - output_types=data_nest.flatten( - sparse.as_dense_types(self._input_dataset.output_types, - self._input_dataset.output_classes)), - target_device=target_device, - string_arg=input_iterator_handle, - buffer_size=buffer_size, - shared_name=shared_name) - self._buffering_resources.append(buffer_resource_handle) - - if not self._one_shot: - reset_ops = [] - for buffer_resource in self._buffering_resources: - reset_ops.append( - ged_ops.experimental_function_buffering_resource_reset( - buffer_resource)) - with ops.control_dependencies(reset_ops): - self._initializer = self._input_iterator.make_initializer( - 
self._input_dataset) - - def get_next(self, name=None): - """See `tf.data.Iterator.get_next`.""" - self._get_next_call_count += 1 - if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: - warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) - - flat_result = [] - # TODO(priyag): This will fail if the input size (typically number of - # batches) is not divisible by number of devices. - # How do we handle that more gracefully / let the user know? - for buffer_resource in self._buffering_resources: - flat_ret = ged_ops.experimental_function_buffering_resource_get_next( - buffer_resource, - output_types=data_nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - name=name) - - ret = sparse.deserialize_sparse_tensors( - data_nest.pack_sequence_as(self.output_types, flat_ret), - self.output_types, self.output_shapes, self.output_classes) - - for tensor, shape in zip( - data_nest.flatten(ret), data_nest.flatten(self.output_shapes)): - if isinstance(tensor, ops.Tensor): - tensor.set_shape(shape) - flat_result.append(ret) - - return nest.pack_sequence_as(self._devices, flat_result) - - @property - def initializer(self): - if self._one_shot: - raise NotImplementedError("Can't initialize a one_shot_iterator") - return self._initializer - - @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types - - -# pylint: enable=protected-access - - -class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset): - """A `Dataset` whose iterator prefetches elements to other device(s).""" - - def __init__(self, input_dataset, devices, buffer_size): - super(_PrefetchToDeviceDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - self._devices = devices - self._buffer_size = buffer_size if buffer_size is not None else 1 - - def 
make_one_shot_iterator(self): - return _PrefetchToDeviceIterator( - self._input_dataset, - one_shot=True, - devices=self._devices, - buffer_size=self._buffer_size) - - def make_initializable_iterator(self, shared_name=None): - if context.executing_eagerly(): - raise RuntimeError( - "make_initializable_iterator is not supported when eager " - "execution is enabled.") - - return _PrefetchToDeviceIterator( - self._input_dataset, - one_shot=False, - devices=self._devices, - buffer_size=self._buffer_size, - shared_name=shared_name) - - def _as_variant_tensor(self): - # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset - # transformation methods is called. - # TODO(mrry): Investigate support for chaining further transformations after - # the prefetch, including GPU support. - raise NotImplementedError("`prefetch_to_devices()` must be the last " - "transformation in a dataset pipeline.") - - # TODO(priyag): Fix the output types, shapes and classes to match the result - # of get_next (which has the additional nesting layer of devices now). - @property - def output_types(self): - return self._input_dataset.output_types - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_classes(self): - return self._input_dataset.output_classes - - -def prefetch_to_devices(devices, buffer_size=None): - """A transformation that prefetches dataset values to the given `devices`. - - NOTE: Although the transformation creates a `tf.data.Dataset`, the - transformation must be the final `Dataset` in the input pipeline. - - Args: - devices: A nested structure of devices on which to prefetch the data. It can - be a single device name, or a tuple or list of device names. - buffer_size: (Optional.) The number of elements to buffer on each device. - Defaults to an automatically chosen value. - - Returns: - A `Dataset` transformation function, which can be passed to - `tf.data.Dataset.apply`. 
- """ - - def _apply_fn(dataset): - return _PrefetchToDeviceDataset(dataset, devices, buffer_size) - - return _apply_fn diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py deleted file mode 100644 index 16799104e8112f4391152c0cf2a15af81f8c2c9d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for prefetching_ops_v2.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distribute.python import prefetching_ops_v2 -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util -from tensorflow.python.platform import test - - -class PrefetchingOpsV2Test(test.TestCase): - - def testPrefetchToOneDevice(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops_v2.prefetch_to_devices("/gpu:0")) - - iterator = device_dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testPrefetchToTwoDevicesInAList(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) - - iterator = device_dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - output = [] - # TODO(rohanj): Modify test to go till the end of the dataset when we - # switch to MultiDeviceIterator. 
- with self.cached_session() as sess: - for _ in range(4): - result = sess.run(next_element) - self.assertEqual(2, len(result)) - output.extend(result) - self.assertEquals(set(range(8)), set(output)) - - def testPrefetchToTwoDevicesWithReinit(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) - - iterator = device_dataset.make_initializable_iterator() - next_element = iterator.get_next() - - # TODO(rohanj): Modify test to go till the end of the dataset when we - # switch to MultiDeviceIterator. - with self.cached_session() as sess: - sess.run(iterator.initializer) - for _ in range(4): - sess.run(next_element) - sess.run(iterator.initializer) - for _ in range(4): - sess.run(next_element) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index 1b5a4f64e5bb1ffabfe1b87c150f713c755bb682..23bf36184fa61105b43c71ada6c343b30dce8376 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.eager import backprop +from tensorflow.python.eager import context from tensorflow.python.training import optimizer as optimizer_lib @@ -50,7 +51,11 @@ class StandardInputStep(Step): def __init__(self, dataset_fn, distribution): super(StandardInputStep, self).__init__(distribution) self._distributed_input = distribution.distribute_dataset(dataset_fn) - self._iterator = self._distributed_input.make_one_shot_iterator() + if context.executing_eagerly(): + self._iterator = self._distributed_input.make_one_shot_iterator() + else: + # TODO(priyag): Expose initializer via some initializer property. 
+ self._iterator = self._distributed_input.make_initializable_iterator() class StandardSingleLossStep(StandardInputStep): diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py index f1ada49fa378358f112fb75a4bcdbe9a8a09cd13..1ff9b9ceec13351b098d47ed3ff62f689a625a31 100644 --- a/tensorflow/contrib/distribute/python/step_fn_test.py +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -50,6 +50,7 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): run_step = single_loss_step else: with self.cached_session() as sess: + sess.run(single_loss_step._iterator.initializer) run_step = sess.make_callable(single_loss_step()) self.evaluate(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 5d498fb629d4a381f56aa7b2db95b09da9010a78..fd280f5754b34170cdd6b948236138d0e77dd8bc 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -115,7 +115,8 @@ class DistributionTestBase(test.TestCase): with ops.control_dependencies([fetched]): g = d.reduce( variable_scope.VariableAggregation.SUM, g, destinations=v) - with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + with ops.control_dependencies(d.update( + v, update, g, grouped=False)): after_list.append(d.read_var(v)) return before_list, after_list @@ -169,7 +170,8 @@ class DistributionTestBase(test.TestCase): with ops.control_dependencies([fetched]): g = d.reduce( variable_scope.VariableAggregation.SUM, g, destinations=v) - with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + with ops.control_dependencies(d.update( + v, update, g, grouped=False)): after_list.append(d.read_var(v)) return before_list, after_list diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py 
b/tensorflow/contrib/distribute/python/tpu_strategy.py index 1b555482d39b92fe3d1df227bda3930df9cdb08b..1d9e299b38409b874610765e54fa0052fafd5f4b 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -132,7 +132,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): """ # TODO(sourabhbajaj): OneDeviceStrategy should be initialized with the # master node fetched from the cluster resolver. - super(TPUStrategy, self).__init__('/device:CPU:0') + super(TPUStrategy, self).__init__("/device:CPU:0") self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) @@ -152,6 +152,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): # at a time is comparable to multiple steps. self.steps_per_run = steps_per_run + self._require_static_shapes = True + def _get_enqueue_op_per_host(self, host_id, iterator, input_shapes, iterations): """Create an enqueue op for a single host identified using host_id. @@ -297,6 +299,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): # For outputs that have already been aggregated, take the first value # from the list as each value should be the same. Else return the full # list of values. + # TODO(josh11b): If aggregation is NONE, we should return a PerDevice value. if aggregation is not variables_lib.VariableAggregation.NONE: # TODO(priyag): Should this return the element or a list with 1 element last_step_tensor_outputs_dict[name] = output[0] @@ -398,11 +401,16 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): return output * (1. / len(value)) return output - def _update(self, var, fn, *args, **kwargs): - # TODO(jhseu): Consider supporting grouped==False. + def _update(self, var, options, fn, *args, **kwargs): assert isinstance(var, values.TPUMirroredVariable) + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. 
+ if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access - return fn(var, *args, **kwargs) + if should_group: + return fn(var, *args, **kwargs) + else: + return [fn(var, *args, **kwargs)] # Otherwise, we revert to MirroredStrategy behavior and update each variable # directly. @@ -414,23 +422,25 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): updates[d] = fn(v, *values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) + return values.update_regroup(self, updates, should_group) - # Make a single control dependency to keep the variables mirrored. If one - # assignment is fetched, then run all assignments. - sorted_keys = sorted(updates.keys()) - update_tuple = control_flow_ops.tuple([updates[d] for d in sorted_keys]) - for i, d in enumerate(sorted_keys): - updates[d] = update_tuple[i] - return values.regroup(updates, values.Mirrored) + # TODO(josh11b): Need to implement _update_non_slot()! def read_var(self, var): assert isinstance(var, values.TPUMirroredVariable) return var.read_value() - def _unwrap(self, value): - if isinstance(value, list): - return value - return [value] + def _unwrap(self, val): + if isinstance(val, values.DistributedValues): + # Return in a deterministic order. + return [val.get(device=d) for d in sorted(val.devices)] + elif isinstance(val, list): + # TODO(josh11b): We need to remove this case; per device values should + # be represented using a PerDevice wrapper instead of a list with + # one entry per device. 
+ return val + return [val] + @property def num_towers(self): diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index c18faeb67d57b7bef764c490c35a56a5f5a84b84..c555dc8a71d0ce56caa95265b39a85578ee85545 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -27,7 +27,7 @@ import weakref import six from tensorflow.contrib.distribute.python import input_ops -from tensorflow.contrib.distribute.python import prefetching_ops_v2 +from tensorflow.python.data.ops import multi_device_iterator_ops from tensorflow.python.eager import context from tensorflow.python.eager import tape from tensorflow.python.framework import device as tf_device @@ -366,18 +366,7 @@ class MirroredVariable(DistributedVariable, Mirrored, # We are calling assign on the mirrored variable in cross tower context, # use update to update the variable. strategy = distribution_strategy_context.get_distribution_strategy() - updates = strategy.update(self, f, *args, **kwargs) - grouped = strategy.group(updates) - if isinstance(updates, DistributedValues) and updates.is_tensor_like: - # Make sure we run all updates. Without this, something like - # session.run(mirrored_var.assign*(...)) may only update one tower. - index = {} - for d in updates.devices: - with ops.device(d), ops.control_dependencies([grouped]): - index[d] = array_ops.identity(updates.get(d)) - return Mirrored(index) - else: - return grouped + return strategy.update(self, f, *args, **kwargs) else: _assert_tower_context() # We are calling an assign function on the mirrored variable in tower @@ -486,6 +475,11 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase): self._aggregation = aggregation # Needed for GradientTape self._trainable = self._primary_var.trainable + # Typically like `DistributedVariable`, a `TPUMirroredVariable`'s + # initializer is composed of the initializers of the components variables. 
+ # However, in some cases, such as when restoring from a checkpoint, we may + # set the _initializer_op property on the entire `TPUMirroredVariable`. + self._initializer_op = None def _get(self, device=None): """Returns the value for the current device or raises a ValueError.""" @@ -582,6 +576,10 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase): ValueError("Device %s not found in %s (current device %s)" % (device, self._index.keys(), device_util.current())), e) + @property + def device(self): + return self._get().device + # The arguments to update() are automatically unwrapped so the update() # function would normally see regular variables, not MirroredVariables. # However, the update function can still operate on wrapped MirroredVariables @@ -711,8 +709,12 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase): @property def initializer(self): - return control_flow_ops.group( - [v.initializer for v in nest.flatten(self._index)]) + if self._initializer_op: + init_op = self._initializer_op + else: + init_op = control_flow_ops.group( + [v.initializer for v in self._index.values()]) + return init_op @property def graph(self): @@ -1049,6 +1051,29 @@ def select_device_mirrored(device, structured): return nest.map_structure(_get_mirrored, structured) +def update_regroup(strategy, updates, should_group): + """Regroup for an update, with dependencies to ensure all updates execute.""" + regrouped = regroup(updates, Mirrored) + if not should_group: + return nest.map_structure(strategy.unwrap, regrouped) + grouped_flat = [] + for u in nest.flatten(regrouped): + if isinstance(u, DistributedValues): + g = strategy.group(u) + if u.is_tensor_like: + # Make sure we run all updates. Without this, something like + # session.run(strategy.update(...)) may only update one tower. 
+ index = {} + for d in u.devices: + with ops.device(d), ops.control_dependencies([g]): + index[d] = array_ops.identity(u.get(d)) + g = Mirrored(index) + else: + g = u + grouped_flat.append(g) + return nest.pack_sequence_as(regrouped, grouped_flat) + + class PerDeviceDataIterator(object): """An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`.""" @@ -1064,7 +1089,7 @@ class PerDeviceDataIterator(object): def get_next(self, name=None): """Scatter the input across devices.""" if self._prefetch_on_device: - data_list = self._iterator.get_next(name=name) + data_list = self._iterator.get_next() index = dict(zip(self._devices, data_list)) else: batch = self._iterator.get_next(name=name) @@ -1088,17 +1113,15 @@ class PerDeviceDataset(object): self._devices = devices # Default to using prefetching in graph mode, unless specified. - # TODO(priyag): Enable prefetching in eager mode. + # TODO(rohanj): Enable prefetching in eager mode. self._prefetch_on_device = prefetch_on_device if self._prefetch_on_device is None: self._prefetch_on_device = not context.executing_eagerly() assert not (self._prefetch_on_device and context.executing_eagerly()), ( "Prefetching is only supported in graph mode currently") - if self._prefetch_on_device: - self._dataset = dataset.apply( - prefetching_ops_v2.prefetch_to_devices(self._devices)) - else: + self._dataset = dataset + if not self._prefetch_on_device: # TODO(priyag): If dropping remainder is not appropriate, find another # approach to distributing the dataset when not possible to divide evenly. # Possibly not an issue when we start using PartitionedDataset. @@ -1106,15 +1129,33 @@ class PerDeviceDataset(object): def make_one_shot_iterator(self): """Get a one time use iterator for the distributed PerDeviceDataset.""" + # Graph mode with one shot iterator is disabled. + if not context.executing_eagerly(): + raise ValueError("Cannot create a one shot iterator. 
Please use " + "`make_initializable_iterator()` instead.") + # Eager mode prefetching would error out in constructor. Only remaining + # case is non-prefetching in eager mode. We delegate to + # PerDeviceDataIterator to handle that case. dataset_iterator = self._dataset.make_one_shot_iterator() - return PerDeviceDataIterator(dataset_iterator, self._devices, - self._prefetch_on_device) + return PerDeviceDataIterator( + dataset_iterator, self._devices, prefetch_on_device=False) def make_initializable_iterator(self): """Get an initializable iterator for the distributed PerDeviceDataset.""" - dataset_iterator = self._dataset.make_initializable_iterator() - return PerDeviceDataIterator(dataset_iterator, self._devices, - self._prefetch_on_device) + # Eager mode generates already initialized iterators. Hence we cannot create + # an initializable iterator. + if context.executing_eagerly(): + raise ValueError("Cannot create initializable iterator in Eager mode. " + "Please use `make_one_shot_iterator` instead.") + if self._prefetch_on_device: + dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator( + self._dataset, self._devices) + else: + dataset_iterator = self._dataset.make_initializable_iterator() + return PerDeviceDataIterator( + dataset_iterator, + self._devices, + prefetch_on_device=self._prefetch_on_device) class MultiWorkerDataIterator(object): diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index ae3e13433364d946a92b951f9fe79ff820212b8a..7ef4776ac6d414470c1597358063f6e77960728f 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -349,7 +349,11 @@ class PerDeviceDatasetTest(test.TestCase): def _test_iterator_no_prefetch(self, devices, dataset, expected_values): per_device_dataset = values.PerDeviceDataset( dataset, devices, prefetch_on_device=False) - iterator = per_device_dataset.make_one_shot_iterator() + if 
context.executing_eagerly(): + iterator = per_device_dataset.make_one_shot_iterator() + else: + iterator = per_device_dataset.make_initializable_iterator() + self.evaluate([iterator.initializer]) for expected_value in expected_values: next_element = iterator.get_next() @@ -366,21 +370,14 @@ class PerDeviceDatasetTest(test.TestCase): if not context.executing_eagerly(): per_device_dataset = values.PerDeviceDataset( dataset, devices, prefetch_on_device=True) - iterator = per_device_dataset.make_one_shot_iterator() + iterator = per_device_dataset.make_initializable_iterator() + self.evaluate([iterator.initializer]) - # With prefetching, we cannot guarantee which input ends up on which - # device, so we verify that the complete set seen on all devices is - # correct, and equal numbers are distributed to each device. - combined_actual = [] - combined_expected = [] for expected_value in expected_values: next_element = iterator.get_next() - combined_actual.extend( - self.evaluate( - [values.select_device(d, next_element) for d in devices])) - combined_expected.extend(expected_value) - - self.assertEqual(set(combined_expected), set(combined_actual)) + computed_value = self.evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() @@ -641,7 +638,7 @@ class MirroredVariableTest(test.TestCase): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - with self.test_session() as sess: + with self.cached_session(config=self.config) as sess: v, devices, mirrored = _make_mirrored() # Overwrite the initial values. 
@@ -744,7 +741,7 @@ class MirroredVariableTest(test.TestCase): if context.num_gpus() < 1 or context.executing_eagerly(): self.skipTest("A GPU is not available for this test or it's eager mode.") - with self.test_session( + with self.session( graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy( ["/device:GPU:0"]).scope(): with ops.device("/device:GPU:0"): @@ -827,7 +824,7 @@ class TowerLocalVariableTest(test.TestCase): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - with self.test_session() as sess: + with self.cached_session(config=self.config) as sess: v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) # Overwrite the initial values. @@ -850,7 +847,7 @@ class TowerLocalVariableTest(test.TestCase): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - with self.test_session() as sess: + with self.cached_session(config=self.config) as sess: v, tower_local = _make_tower_local( variable_scope.VariableAggregation.MEAN) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 3ff7da4f89c1145b323c52d671675343a4f5e98c..60f6b90edcb71f04bca29b90744db201e83cd545 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -299,7 +299,7 @@ cuda_py_test( cuda_py_test( name = "mvn_diag_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/mvn_diag_test.py"], additional_deps = [ ":distributions_py", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 5cec93c4df2e970f203253be6342bb292f296eb0..5f6b7fe30996aa97653d97bffb007703437c3d14 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -13,74 +13,80 @@ # limitations under the License. 
# ============================================================================== """Classes representing statistical distributions and ops for working with them. + +Use [tfp.distributions](/probability/api_docs/python/tfp/distributions) instead. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member +from tensorflow.python.util import deprecation + + +# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member,g-import-not-at-top -from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.contrib.distributions.python.ops.autoregressive import * -from tensorflow.contrib.distributions.python.ops.batch_reshape import * -from tensorflow.contrib.distributions.python.ops.binomial import * -from tensorflow.contrib.distributions.python.ops.cauchy import * -from tensorflow.contrib.distributions.python.ops.chi2 import * -from tensorflow.contrib.distributions.python.ops.conditional_distribution import * -from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * -from tensorflow.contrib.distributions.python.ops.deterministic import * -from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular -from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse -from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform -from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp -from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse -from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag -from tensorflow.contrib.distributions.python.ops.estimator import * -from tensorflow.contrib.distributions.python.ops.geometric import * -from 
tensorflow.contrib.distributions.python.ops.half_normal import * -from tensorflow.contrib.distributions.python.ops.independent import * -from tensorflow.contrib.distributions.python.ops.inverse_gamma import * -from tensorflow.contrib.distributions.python.ops.kumaraswamy import * -from tensorflow.contrib.distributions.python.ops.logistic import * -from tensorflow.contrib.distributions.python.ops.mixture import * -from tensorflow.contrib.distributions.python.ops.mixture_same_family import * -from tensorflow.contrib.distributions.python.ops.moving_stats import * -from tensorflow.contrib.distributions.python.ops.mvn_diag import * -from tensorflow.contrib.distributions.python.ops.mvn_diag_plus_low_rank import * -from tensorflow.contrib.distributions.python.ops.mvn_full_covariance import * -from tensorflow.contrib.distributions.python.ops.mvn_tril import * -from tensorflow.contrib.distributions.python.ops.negative_binomial import * -from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import * -from tensorflow.contrib.distributions.python.ops.onehot_categorical import * -from tensorflow.contrib.distributions.python.ops.poisson import * -from tensorflow.contrib.distributions.python.ops.poisson_lognormal import * -from tensorflow.contrib.distributions.python.ops.quantized_distribution import * -from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * -from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * -from tensorflow.contrib.distributions.python.ops.sample_stats import * -from tensorflow.contrib.distributions.python.ops.seed_stream import * -from tensorflow.contrib.distributions.python.ops.sinh_arcsinh import * -from tensorflow.contrib.distributions.python.ops.test_util import * -from tensorflow.contrib.distributions.python.ops.vector_diffeomixture import * -from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import * -from 
tensorflow.contrib.distributions.python.ops.vector_laplace_diag import * -from tensorflow.contrib.distributions.python.ops.vector_sinh_arcsinh_diag import * -from tensorflow.contrib.distributions.python.ops.wishart import * -from tensorflow.python.ops.distributions.bernoulli import * -from tensorflow.python.ops.distributions.beta import * -from tensorflow.python.ops.distributions.categorical import * -from tensorflow.python.ops.distributions.dirichlet import * -from tensorflow.python.ops.distributions.dirichlet_multinomial import * -from tensorflow.python.ops.distributions.distribution import * -from tensorflow.python.ops.distributions.exponential import * -from tensorflow.python.ops.distributions.gamma import * -from tensorflow.python.ops.distributions.kullback_leibler import * -from tensorflow.python.ops.distributions.laplace import * -from tensorflow.python.ops.distributions.multinomial import * -from tensorflow.python.ops.distributions.normal import * -from tensorflow.python.ops.distributions.student_t import * -from tensorflow.python.ops.distributions.transformed_distribution import * -from tensorflow.python.ops.distributions.uniform import * +with deprecation.silence(): + from tensorflow.contrib.distributions.python.ops import bijectors + from tensorflow.contrib.distributions.python.ops.autoregressive import * + from tensorflow.contrib.distributions.python.ops.batch_reshape import * + from tensorflow.contrib.distributions.python.ops.binomial import * + from tensorflow.contrib.distributions.python.ops.cauchy import * + from tensorflow.contrib.distributions.python.ops.chi2 import * + from tensorflow.contrib.distributions.python.ops.conditional_distribution import * + from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * + from tensorflow.contrib.distributions.python.ops.deterministic import * + from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular + from 
tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse + from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform + from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp + from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse + from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag + from tensorflow.contrib.distributions.python.ops.estimator import * + from tensorflow.contrib.distributions.python.ops.geometric import * + from tensorflow.contrib.distributions.python.ops.half_normal import * + from tensorflow.contrib.distributions.python.ops.independent import * + from tensorflow.contrib.distributions.python.ops.inverse_gamma import * + from tensorflow.contrib.distributions.python.ops.kumaraswamy import * + from tensorflow.contrib.distributions.python.ops.logistic import * + from tensorflow.contrib.distributions.python.ops.mixture import * + from tensorflow.contrib.distributions.python.ops.mixture_same_family import * + from tensorflow.contrib.distributions.python.ops.moving_stats import * + from tensorflow.contrib.distributions.python.ops.mvn_diag import * + from tensorflow.contrib.distributions.python.ops.mvn_diag_plus_low_rank import * + from tensorflow.contrib.distributions.python.ops.mvn_full_covariance import * + from tensorflow.contrib.distributions.python.ops.mvn_tril import * + from tensorflow.contrib.distributions.python.ops.negative_binomial import * + from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import * + from tensorflow.contrib.distributions.python.ops.onehot_categorical import * + from tensorflow.contrib.distributions.python.ops.poisson import * + from tensorflow.contrib.distributions.python.ops.poisson_lognormal import * + from tensorflow.contrib.distributions.python.ops.quantized_distribution import * + from 
tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * + from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * + from tensorflow.contrib.distributions.python.ops.sample_stats import * + from tensorflow.contrib.distributions.python.ops.seed_stream import * + from tensorflow.contrib.distributions.python.ops.sinh_arcsinh import * + from tensorflow.contrib.distributions.python.ops.test_util import * + from tensorflow.contrib.distributions.python.ops.vector_diffeomixture import * + from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import * + from tensorflow.contrib.distributions.python.ops.vector_laplace_diag import * + from tensorflow.contrib.distributions.python.ops.vector_sinh_arcsinh_diag import * + from tensorflow.contrib.distributions.python.ops.wishart import * + from tensorflow.python.ops.distributions.bernoulli import * + from tensorflow.python.ops.distributions.beta import * + from tensorflow.python.ops.distributions.categorical import * + from tensorflow.python.ops.distributions.dirichlet import * + from tensorflow.python.ops.distributions.dirichlet_multinomial import * + from tensorflow.python.ops.distributions.distribution import * + from tensorflow.python.ops.distributions.exponential import * + from tensorflow.python.ops.distributions.gamma import * + from tensorflow.python.ops.distributions.kullback_leibler import * + from tensorflow.python.ops.distributions.laplace import * + from tensorflow.python.ops.distributions.multinomial import * + from tensorflow.python.ops.distributions.normal import * + from tensorflow.python.ops.distributions.student_t import * + from tensorflow.python.ops.distributions.transformed_distribution import * + from tensorflow.python.ops.distributions.uniform import * # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py 
b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index e141f8b5c6423bd6cce4d09da6f49d55b3e25a24..3b17de9b8a903956bfdc4d46cf5bbfbfd8530e9f 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -14,6 +14,8 @@ # ============================================================================== """Bijector Ops. +Use [tfp.bijectors](/probability/api_docs/python/tfp/bijectors) instead. + @@AbsoluteValue @@Affine @@AffineLinearOperator diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 33a1d572a20e68479d3ec1147d4892449e7beb8a..77052a75a70bec1162feb2b126d247924b3a2e36 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -28,6 +28,7 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:execution_callbacks", "//tensorflow/python/eager:function", ], @@ -249,11 +250,10 @@ py_library( ], ) -py_test( +cuda_py_test( name = "remote_test", srcs = ["remote_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":parameter_server", ":remote", "//tensorflow/contrib/eager/python:tfe", diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 135095a97980da8988b976948fb18492526e390c..3aed121233be1268531495a2fa83fd323412e1fd 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.python.data.experimental.ops import prefetching_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import 
context from tensorflow.python.framework import ops @@ -54,7 +54,7 @@ class Iterator(iterator_ops.EagerIterator): """ if isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset): # pylint: disable=protected-access raise TypeError( - "`tf.contrib.data.prefetch_to_device()` is not compatible with " + "`tf.data.experimental.prefetch_to_device()` is not compatible with " "`tf.contrib.eager.Iterator`. Use `for ... in dataset:` to iterate " "over the dataset instead.") diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index a753d77580758af9de8410de4a08f7ea278c4c79..6a508fc6ba98740c4d441a064dc8a3e2b321f585 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -24,11 +24,11 @@ import time import numpy as np from tensorflow.contrib import lookup -from tensorflow.contrib.data.python.ops import prefetching_ops -from tensorflow.contrib.data.python.ops import threadpool -from tensorflow.contrib.data.python.ops import unique from tensorflow.contrib.eager.python import datasets from tensorflow.python.data import Dataset +from tensorflow.python.data.experimental.ops import prefetching_ops +from tensorflow.python.data.experimental.ops import threadpool +from tensorflow.python.data.experimental.ops import unique from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py index e5058bfd9480e25b3cf040f0d96bf21242a147b8..a9fb0035d299d64b35d756eaf1ae5f7034ff5599 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py @@ -228,6 +228,7 @@ class DensenetBenchmark(tf.test.Benchmark): weight_decay=1e-4, dropout_rate=0, pool_initial=True, 
include_top=True) if defun: + # TODO(apassos) enable tfe.function here model.call = tfe.defun(model.call) batch_size = 64 num_burn = 5 diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index e0d5e494d432b365b0d1dcff6b634de2e6213a43..bda9e77085e45ae31a228142135425e22a1c6780 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -567,7 +567,7 @@ "\n", "* We get predictions using the start_string and the hidden state\n", "\n", - "* Then we use a multinomial distribution to calculate the index of the predicted word. **We use this predicted word as our next input to the model**\n", + "* Then we use argmax to calculate the index of the predicted word. **We use this predicted word as our next input to the model**\n", "\n", "* **The hidden state returned by the model is fed back into the model so that it now has more context rather than just one word.** After we predict the next word, the modified hidden states are again fed back into the model, which is how it learns as it gets more context from the previously predicted words.\n", "\n", @@ -598,19 +598,13 @@ "# empty string to store our results\n", "text_generated = ''\n", "\n", - "# low temperatures results in more predictable text.\n", - "# higher temperatures results in more surprising text\n", - "# experiment to find the best setting\n", - "temperature = 1.0\n", - "\n", "# hidden state shape == (batch_size, number of rnn units); here batch size == 1\n", "hidden = [tf.zeros((1, units))]\n", "for i in range(num_generate):\n", " predictions, hidden = model(input_eval, hidden)\n", "\n", - " # using a multinomial distribution to predict the word returned by the model\n", - " predictions = predictions / temperature\n", - " predicted_id = 
tf.argmax(predictions[0]).numpy()\n", + " # using argmax to predict the word returned by the model\n", + " predicted_id = tf.argmax(predictions[-1]).numpy()\n", " \n", " # We pass the predicted word as the next input to the model\n", " # along with the previous hidden state\n", @@ -632,7 +626,6 @@ "\n", "* Change the start string to a different character, or the start of a sentence.\n", "* Experiment with training on a different, or with different parameters. [Project Gutenberg](http://www.gutenberg.org/ebooks/100), for example, contains a large collection of books.\n", - "* Experiment with the temperature parameter.\n", "* Add another RNN layer.\n" ] }, diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/main.py b/tensorflow/contrib/eager/python/examples/l2hmc/main.py index 45e1f98429f48749d374c2aefd8874690c3830ad..98fcb2ba10aa4148dc1d4bd7ddfb6fa9c8c4537c 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/main.py +++ b/tensorflow/contrib/eager/python/examples/l2hmc/main.py @@ -71,7 +71,7 @@ def main(_): # Training if FLAGS.use_defun: # Use `tfe.deun` to boost performance when there are lots of small ops - loss_fn = tfe.defun(l2hmc.compute_loss) + loss_fn = tfe.function(l2hmc.compute_loss) else: loss_fn = l2hmc.compute_loss @@ -104,7 +104,7 @@ def main(_): # Evaluation if FLAGS.use_defun: # Use tfe.deun to boost performance when there are lots of small ops - apply_transition = tfe.defun(dynamics.apply_transition) + apply_transition = tfe.function(dynamics.apply_transition) else: apply_transition = dynamics.apply_transition diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 560fc8c5a22a0e7acf1f37cf7daf7790dc14de19..480777d948769b56ac1cc3be2052fe48459e98d6 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ 
b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -352,7 +352,7 @@ "And the pseudo-code:\n", "\n", "* `score = FC(tanh(FC(EO) + FC(H)))`\n", - "* `attention weights = softmax(score, axis = 1)`. Softmax by default is applied on the last axis but here we want to apply it on the *1st axis*, since the shape of score is *(batch_size, max_length, hidden_size)*. `Max_length` is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.\n", + "* `attention weights = softmax(score, axis = 1)`. Softmax by default is applied on the last axis but here we want to apply it on the *1st axis*, since the shape of score is *(batch_size, max_length, 1)*. `Max_length` is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.\n", "* `context vector = sum(attention weights * EO, axis = 1)`. Same reason as above for choosing axis as 1.\n", "* `embedding output` = The input to the decoder X is passed through an embedding layer.\n", "* `merged vector = concat(embedding output, context vector)`\n", @@ -446,12 +446,12 @@ " # we are doing this to perform addition to calculate the score\n", " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", " \n", - " # score shape == (batch_size, max_length, hidden_size)\n", - " score = tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis))\n", + " # score shape == (batch_size, max_length, 1)\n", + " # we get 1 at the last axis because we are applying tanh(FC(EO) + FC(H)) to self.V\n", + " score = self.V(tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis)))\n", " \n", " # attention_weights shape == (batch_size, max_length, 1)\n", - " # we get 1 at the last axis because we are applying score to self.V\n", - " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", + " attention_weights = tf.nn.softmax(score, axis=1)\n", " \n", " # context_vector shape after sum == 
(batch_size, hidden_size)\n", " context_vector = attention_weights * enc_output\n", diff --git a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb index 8fae622e12864ddeee0cedd3cf99be8ea5e4bc48..446e3401184ded6bc34ed64cdd720e29a2851855 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb @@ -65,7 +65,7 @@ "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\n", " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] } ], diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index 551c76b0df71c88919df9cd6d81b4176b23b0ba3..f3bb978875e226f58d6a00e09154191673a97415 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -51,7 +51,9 @@ def random_batch(batch_size): class 
ResNet50GraphTest(tf.test.TestCase): def testApply(self): - batch_size = 64 + # Use small batches for tests because the OSS version runs + # in constrained GPU environment with 1-2GB of memory. + batch_size = 8 with tf.Graph().as_default(): images = tf.placeholder(tf.float32, image_shape(None)) model = resnet50.ResNet50(data_format()) @@ -63,7 +65,7 @@ class ResNet50GraphTest(tf.test.TestCase): sess.run(init) np_images, _ = random_batch(batch_size) out = sess.run(predictions, feed_dict={images: np_images}) - self.assertAllEqual([64, 1000], out.shape) + self.assertAllEqual([batch_size, 1000], out.shape) def testTrainWithSummary(self): with tf.Graph().as_default(): @@ -87,7 +89,9 @@ class ResNet50GraphTest(tf.test.TestCase): init = tf.global_variables_initializer() self.assertEqual(321, len(tf.global_variables())) - batch_size = 32 + # Use small batches for tests because the OSS version runs + # in constrained GPU environment with 1-2GB of memory. + batch_size = 2 with tf.Session() as sess: sess.run(init) sess.run(tf.contrib.summary.summary_writer_initializer_op()) diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index d265169b5eff685f7b79fb221b9bd52be37ead9c..fb81979d7bd8d17a55b8c448008765268dd07d1d 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -77,7 +77,7 @@ class ResNet50Test(tf.test.TestCase): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) if defun: - model.call = tfe.defun(model.call) + model.call = tfe.function(model.call) with tf.device(device), tfe.execution_mode(execution_mode): images, _ = random_batch(2, data_format) output = model(images, training=False) @@ -221,7 +221,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): device, data_format = device_and_format model = resnet50.ResNet50(data_format) if defun: - 
model.call = tfe.defun(model.call) + model.call = tfe.function(model.call) batch_size = 64 num_burn = 5 num_iters = 30 @@ -266,8 +266,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): optimizer = tf.train.GradientDescentOptimizer(0.1) apply_grads = apply_gradients if defun: - model.call = tfe.defun(model.call) - apply_grads = tfe.defun(apply_gradients) + model.call = tfe.function(model.call) + apply_grads = tfe.function(apply_gradients) num_burn = 3 num_iters = 10 diff --git a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py index 34a9984b0ecc527ad1991c28146246b716e96c98..d85188de030af2bbab1c141b5c090371248110b9 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py +++ b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py @@ -169,11 +169,11 @@ class ImageNetInput(object): # Read the data from disk in parallel dataset = dataset.apply( - tf.contrib.data.parallel_interleave( + tf.data.experimental.parallel_interleave( fetch_dataset, cycle_length=self.num_parallel_calls, sloppy=True)) if self.cache: dataset = dataset.cache().apply( - tf.contrib.data.shuffle_and_repeat(1024 * 16)) + tf.data.experimental.shuffle_and_repeat(1024 * 16)) else: dataset = dataset.shuffle(1024) @@ -188,9 +188,11 @@ class ImageNetInput(object): # batch size. As long as this validation is done with consistent batch size, # exactly the same images will be used. 
dataset = dataset.apply( - tf.contrib.data.map_and_batch( - self.dataset_parser, batch_size=batch_size, - num_parallel_batches=self.num_cores, drop_remainder=True)) + tf.data.experimental.map_and_batch( + self.dataset_parser, + batch_size=batch_size, + num_parallel_batches=self.num_cores, + drop_remainder=True)) # Transpose for performance on TPU if self.transpose_input: diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py index 6a921e19978fdf6e3c20974b2c349bd6923b5782..971aa44f3034692dfb0d03ed3dabf4d6e911eb9f 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -50,6 +50,9 @@ class RevNetTest(tf.test.TestCase): # Reconstruction could cause numerical error, use double precision for tests config.dtype = tf.float64 config.fused = False # Fused batch norm does not support tf.float64 + # Reduce the batch size for tests because the OSS version runs + # in constrained GPU environment with 1-2GB of memory. + config.batch_size = 2 shape = (config.batch_size,) + config.input_shape self.model = revnet.RevNet(config=config) self.x = tf.random_normal(shape=shape, dtype=tf.float64) @@ -124,6 +127,8 @@ class RevNetTest(tf.test.TestCase): def test_compute_gradients_defun(self): """Test `compute_gradients` function with defun.""" + # TODO(apassos): make cond support returning None to let this happen with + # tf.function. 
compute_gradients = tfe.defun(self.model.compute_gradients) _, saved_hidden = self.model(self.x) grads, _ = compute_gradients(saved_hidden=saved_hidden, labels=self.t) @@ -232,6 +237,7 @@ class RevNetBenchmark(tf.test.Benchmark): device, data_format = device_and_format model = revnet.RevNet(config=config) if defun: + # TODO(apassos): reenable after cond lets you return None model.call = tfe.defun(model.call) batch_size = 64 num_burn = 5 @@ -279,7 +285,7 @@ class RevNetBenchmark(tf.test.Benchmark): model = revnet.RevNet(config=config) optimizer = tf.train.GradientDescentOptimizer(0.1) if defun: - model.call = tfe.defun(model.call) + model.call = tfe.function(model.call) num_burn = 3 num_iters = 10 diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py index ba6fe9701df74361f2160195606efbe5bbcb6857..3926de15e71c9917f88fc3f58740b8c75354ab26 100644 --- a/tensorflow/contrib/eager/python/remote_test.py +++ b/tensorflow/contrib/eager/python/remote_test.py @@ -47,8 +47,9 @@ def run_sync_and_async(f): @functools.wraps(f) def decorator(self, *args, **kwargs): - with context.execution_mode(context.ASYNC): - f(self, *args, **kwargs) + # TODO(b/117110239): Re-enable. + # with context.execution_mode(context.ASYNC): + # f(self, *args, **kwargs) with context.execution_mode(context.SYNC): f(self, *args, **kwargs) @@ -205,6 +206,33 @@ class RemoteExecutionTest(test.TestCase): y = math_ops.matmul(x1, x2) np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + @run_sync_and_async + def testContextDeviceUpdated(self): + """Tests that the context device is correctly updated.""" + + with ops.device("cpu:0"): + x1 = array_ops.ones([2, 2]) + x2 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + # `y` is placed on the local CPU as expected. 
+ self.assertEqual(y.device, + "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME) + + @run_sync_and_async + def testGPUToRemoteCopy(self): + """Tests that the remote copy happens satisfactorily.""" + if not context.context().num_gpus(): + self.skipTest("No GPUs.") + + x1 = array_ops.ones([2, 2]).gpu() + + with ops.device("/job:remote_device/replica:0/task:1/device:CPU:0"): + x2 = x1._copy() # pylint: disable=protected-access + + np.testing.assert_array_equal(x1.numpy(), x2.numpy()) + if __name__ == "__main__": ops.enable_eager_execution() diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index f5b8d95e4fc7fe5cd90d658eda49590e0b330bb0..33c988fd9065e7fbe7b9aeb85cad82eb3c119f76 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -25,6 +25,7 @@ To use, at program startup, call `tf.enable_eager_execution()`. @@py_func @@defun +@@function @@make_template @@implicit_gradients @@implicit_value_and_gradients @@ -101,7 +102,7 @@ from tensorflow.contrib.eager.python.saver import get_optimizer_variables from tensorflow.contrib.eager.python.saver import restore_variables_on_create from tensorflow.contrib.eager.python.saver import Saver from tensorflow.python.eager import backprop -from tensorflow.python.eager import function +from tensorflow.python.eager import function as _function_lib from tensorflow.python.eager.context import DEVICE_PLACEMENT_EXPLICIT from tensorflow.python.eager.context import DEVICE_PLACEMENT_WARN from tensorflow.python.eager.context import DEVICE_PLACEMENT_SILENT @@ -115,6 +116,7 @@ from tensorflow.python.eager.context import SYNC from tensorflow.python.eager.context import ASYNC from tensorflow.python.eager.context import num_gpus from tensorflow.python.eager.context import set_server_def +from tensorflow.python.eager.def_function import function from tensorflow.python.eager.execution_callbacks import add_execution_callback from 
tensorflow.python.eager.execution_callbacks import clear_execution_callbacks from tensorflow.python.eager.execution_callbacks import inf_callback @@ -138,7 +140,7 @@ from tensorflow.python.training.checkpointable.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented py_func = script_ops.eager_py_func -defun = function.defun +defun = _function_lib.defun make_template = template.make_template_internal implicit_gradients = backprop.implicit_grad implicit_value_and_gradients = backprop.implicit_val_and_grad diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py index a1f1c5f3d7a25ad28c58e9c215b862b6d51f4cd8..b131ed4f12a01a0087390b5bb65f3ac2d5aec657 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py @@ -75,7 +75,7 @@ class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase): # pylint: layer. head: the `Head` instance defined for Estimator. model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator + also be used to load checkpoints from the directory into an estimator to continue training a previously saved model. weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -199,7 +199,7 @@ def boosted_trees_classifier_train_in_memory( the model. All items in the set should be instances of classes derived from `FeatureColumn`. model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator + also be used to load checkpoints from the directory into an estimator to continue training a previously saved model. n_classes: number of label classes. Default is binary classification. Multiclass support is not yet implemented. 
@@ -345,7 +345,7 @@ def boosted_trees_regressor_train_in_memory( the model. All items in the set should be instances of classes derived from `FeatureColumn`. model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator + also be used to load checkpoints from the directory into an estimator to continue training a previously saved model. label_dimension: Number of regression targets per example. Multi-dimensional support is not yet implemented. diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py index 724bc2c82f8289bbaa19a1dbbc1dc81b6e158e02..4e7965ef265022214f88ed74f4c8502fc8a4c897 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py @@ -118,7 +118,7 @@ class DNNLinearCombinedEstimator(estimator.Estimator): head: A `_Head` instance constructed with a method such as `tf.contrib.estimator.multi_label_head`. model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator + also be used to load checkpoints from the directory into an estimator to continue training a previously saved model. linear_feature_columns: An iterable containing all the feature columns used by linear part of the model. 
All items in the set must be diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py index 5faf0aacfe57b3fcd716bfb0f73842e4e8180cbc..40a91175b71f27bb9ca72a238a5aea172cf4c360 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py @@ -151,7 +151,7 @@ def make_input_layer_with_layer_annotations(original_input_layer): # spec and looking at the keys. spec = feature_column_lib.make_parse_example_spec(feature_columns) for key in spec.keys(): - tensor = ops.convert_to_tensor(features[key]) + tensor = ops.convert_to_tensor_or_indexed_slices(features[key]) ops.add_to_collection( LayerAnnotationsCollectionNames.keys( LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key) @@ -248,7 +248,7 @@ def DNNClassifierWithLayerAnnotations( # pylint: disable=invalid-name model. All items in the set should be instances of classes derived from `_FeatureColumn`. model_dir: Directory to save model parameters, graph and etc. This can also - be used to load checkpoints from the directory into a estimator to + be used to load checkpoints from the directory into an estimator to continue training a previously saved model. n_classes: Number of label classes. Defaults to 2, namely binary classification. Must be > 1. 
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index ce758992140d43529037b14cbbf958d5aa763fb4..6e793c830244e64cd11c4054918c18a8251be7ac 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -233,6 +233,22 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access self, features, mode, logits, labels=None, optimizer=None, train_op_fn=None): """See `_Head`.""" + return self._create_estimator_spec( + features=features, mode=mode, logits=logits, labels=labels, + optimizer=optimizer, train_op_fn=train_op_fn, use_tpu=False) + + def _create_tpu_estimator_spec( + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None): + """See `_Head`.""" + return self._create_estimator_spec( + features=features, mode=mode, logits=logits, labels=labels, + optimizer=optimizer, train_op_fn=train_op_fn, use_tpu=True) + + def _create_estimator_spec( + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, use_tpu=False): + """Returns `EstimatorSpec` or `TPUEstimatorSpec`.""" if isinstance(logits, dict): logits_dict = logits else: @@ -255,14 +271,15 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access spec = self._merge_train( all_estimator_spec=all_estimator_spec, optimizer=optimizer, - train_op_fn=train_op_fn) + train_op_fn=train_op_fn, + use_tpu=use_tpu) with ops.name_scope(''): summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss) return spec if mode == model_fn.ModeKeys.PREDICT: - return self._merge_predict(all_estimator_spec) + return self._merge_predict(all_estimator_spec, use_tpu=use_tpu) if mode == model_fn.ModeKeys.EVAL: - return self._merge_eval(all_estimator_spec) + return self._merge_eval(all_estimator_spec, use_tpu=use_tpu) raise ValueError('mode={} unrecognized'.format(mode)) def _split_logits(self, logits): @@ -284,28 +301,28 @@ 
class _MultiHead(head_lib._Head): # pylint:disable=protected-access begin_idx += head.logits_dimension return logits_dict - def _merge_train(self, all_estimator_spec, optimizer, train_op_fn): - """Merges list of `EstimatorSpec` for training. + def _merge_train( + self, all_estimator_spec, optimizer, train_op_fn, use_tpu=False): + """Merges list of `EstimatorSpec` or `TPUEstimatorSpec` for training. Args: - all_estimator_spec: list of `EstimatorSpec` for the individual heads. + all_estimator_spec: list of `EstimatorSpec` or `TPUEstimatorSpec` for the + individual heads. optimizer: `Optimizer` instance to create train op. See `create_estimator_spec` documentation for more details. train_op_fn: Function to create train op. Used if `optimizer` is `None`. + use_tpu: If `True`, returns `TPUEstimatorSpec`. Returns: - `EstimatorSpec` that merges all heads for TRAIN. + `EstimatorSpec` or `TPUEstimatorSpec` that merges all heads for TRAIN. Raises: ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode. """ losses = [] - metrics = {} for spec in all_estimator_spec: losses.append(spec.loss) - # Metric keys already contain head.name. - metrics.update(spec.eval_metric_ops or {}) loss = _merge_losses(losses, self._head_weights) if optimizer is not None: if train_op_fn is not None: @@ -317,20 +334,23 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access else: raise ValueError('train_op_fn and optimizer cannot both be None.') - return model_fn.EstimatorSpec( + spec_type = ( + model_fn._TPUEstimatorSpec if use_tpu else model_fn.EstimatorSpec) # pylint:disable=protected-access + return spec_type( mode=model_fn.ModeKeys.TRAIN, loss=loss, - train_op=train_op, - eval_metric_ops=metrics) + train_op=train_op) - def _merge_predict(self, all_estimator_spec): - """Merges list of `EstimatorSpec` for prediction. + def _merge_predict(self, all_estimator_spec, use_tpu=False): + """Merges list of `EstimatorSpec` or `TPUEstimatorSpec` for prediction. 
Args: - all_estimator_spec: list of `EstimatorSpec` for the individual heads. + all_estimator_spec: list of `EstimatorSpec` or `TPUEstimatorSpec` for the + individual heads. + use_tpu: If `True`, returns `TPUEstimatorSpec`. Returns: - `EstimatorSpec` that merges all heads for PREDICT. + `EstimatorSpec` or `TPUEstimatorSpec` that merges all heads for PREDICT. """ predictions = {} export_outputs = { @@ -357,20 +377,29 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access export_outputs[head_lib._PREDICT_SERVING_KEY] = ( # pylint:disable=protected-access export_output_lib.PredictOutput(merged_predict_outputs)) - return model_fn.EstimatorSpec( + spec_type = ( + model_fn._TPUEstimatorSpec if use_tpu else model_fn.EstimatorSpec) # pylint:disable=protected-access + return spec_type( mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs=export_outputs) - def _merge_eval(self, all_estimator_spec): + def _merge_eval(self, all_estimator_spec, use_tpu=False): """Merges list of `EstimatorSpec` for eval. Args: all_estimator_spec: list of `EstimatorSpec` for the individual heads. + use_tpu: If `True`, will raise `NotImplementedError`, because TPU is not + yet supported for eval. Returns: `EstimatorSpec` that merges all heads for EVAL. + Raises: + NotImplementedError: If `use_tpu` is `True`. 
""" + if use_tpu: + raise NotImplementedError( + 'TPU evaluation is not implemented for multi_head.') predictions = {} metrics = {} losses = [] diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 2b4d5f526199c500ad77a0422215381ac3a1cf69..a602f87b4a2b4062efddf819522fb2d1eeceaabe 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -106,7 +106,7 @@ class MultiHeadTest(test.TestCase): multi_head = multi_head_lib.multi_head([head1, head2]) self.assertEqual('head1_head2', multi_head.name) - def test_predict_two_heads_logits_dict(self): + def _test_predict_two_heads_logits_dict(self, use_tpu): """Tests predict with logits as dict.""" head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') @@ -121,10 +121,16 @@ class MultiHeadTest(test.TestCase): 'head2': _sigmoid(logits['head2']), } - spec = multi_head.create_estimator_spec( - features={'x': np.array(((42,),), dtype=np.int32)}, - mode=model_fn.ModeKeys.PREDICT, - logits=logits) + if use_tpu: + spec = multi_head._create_tpu_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits).as_estimator_spec() + else: + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) self.assertItemsEqual( (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification', @@ -175,6 +181,12 @@ class MultiHeadTest(test.TestCase): sess.run( spec.export_outputs['head2/predict'].outputs['probabilities'])) + def test_predict_two_heads_logits_dict(self): + self._test_predict_two_heads_logits_dict(use_tpu=False) + + def test_predict_two_heads_logits_dict_tpu(self): + self._test_predict_two_heads_logits_dict(use_tpu=True) + def 
test_predict_two_heads_logits_tensor(self): """Tests predict with logits as Tensor.""" head1 = head_lib.multi_label_head(n_classes=2, name='head1') @@ -350,6 +362,31 @@ class MultiHeadTest(test.TestCase): rtol=tol, atol=tol) + def test_eval_tpu(self): + head1 = head_lib.multi_label_head(n_classes=2, name='head1') + head2 = head_lib.multi_label_head(n_classes=3, name='head2') + multi_head = multi_head_lib.multi_head( + [head1, head2], head_weights=[1., 2.]) + + logits = { + 'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + 'head2': np.array([[20., -20., 20.], [-30., 20., -20.]], + dtype=np.float32), + } + labels = { + 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), + 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), + } + + with self.assertRaisesRegexp( + NotImplementedError, + r'TPU evaluation is not implemented for multi_head\.'): + multi_head._create_tpu_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=logits, + labels=labels) + def test_train_create_loss_one_head(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') multi_head = multi_head_lib.multi_head([head1]) @@ -587,7 +624,7 @@ class MultiHeadTest(test.TestCase): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) - def test_train_two_heads_with_weights(self): + def _test_train_two_heads_with_weights(self, use_tpu): head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') multi_head = multi_head_lib.multi_head( @@ -619,12 +656,20 @@ class MultiHeadTest(test.TestCase): [constant_op.constant(expected_train_result), string_ops.as_string(loss, precision=3)]) - spec = multi_head.create_estimator_spec( - features={'x': np.array(((42,),), dtype=np.int32)}, - mode=model_fn.ModeKeys.TRAIN, - logits=logits, - labels=labels, - train_op_fn=_train_op_fn) + if use_tpu: + spec = multi_head._create_tpu_estimator_spec( + 
features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn).as_estimator_spec() + else: + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) self.assertIsNotNone(spec.loss) self.assertEqual({}, spec.eval_metric_ops) @@ -649,6 +694,12 @@ class MultiHeadTest(test.TestCase): metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2, }, summary_str, tol) + def test_train_two_heads_with_weights(self): + self._test_train_two_heads_with_weights(use_tpu=False) + + def test_train_two_heads_with_weights_tpu(self): + self._test_train_two_heads_with_weights(use_tpu=True) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py index 98660bb7317ae76a7da7c90a5c890ab8e69037fe..c595f473950e28cd75cd1b56c1b3d409333dbc74 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -30,7 +30,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.layers import core as core_layers from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables @@ -92,55 +91,6 @@ def _make_rnn_cell_fn(num_units, cell_type='basic_rnn'): return rnn_cell_fn -def _concatenate_context_input(sequence_input, context_input): - """Replicates `context_input` across all timesteps of `sequence_input`. - - Expands dimension 1 of `context_input` then tiles it `sequence_length` times. - This value is appended to `sequence_input` on dimension 2 and the result is - returned. 
- - Args: - sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size, - padded_length, d0]`. - context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`. - - Returns: - A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, - d0 + d1]`. - - Raises: - ValueError: If `sequence_input` does not have rank 3 or `context_input` does - not have rank 2. - """ - seq_rank_check = check_ops.assert_rank( - sequence_input, - 3, - message='sequence_input must have rank 3', - data=[array_ops.shape(sequence_input)]) - seq_type_check = check_ops.assert_type( - sequence_input, - dtypes.float32, - message='sequence_input must have dtype float32; got {}.'.format( - sequence_input.dtype)) - ctx_rank_check = check_ops.assert_rank( - context_input, - 2, - message='context_input must have rank 2', - data=[array_ops.shape(context_input)]) - ctx_type_check = check_ops.assert_type( - context_input, - dtypes.float32, - message='context_input must have dtype float32; got {}.'.format( - context_input.dtype)) - with ops.control_dependencies( - [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]): - padded_length = array_ops.shape(sequence_input)[1] - tiled_context_input = array_ops.tile( - array_ops.expand_dims(context_input, 1), - array_ops.concat([[1], [padded_length], [1]], 0)) - return array_ops.concat([sequence_input, tiled_context_input], 2) - - def _select_last_activations(activations, sequence_lengths): """Selects the nth set of activations for each n in `sequence_length`. @@ -222,8 +172,8 @@ def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns, context_input = feature_column_lib.input_layer( features=features, feature_columns=context_feature_columns) - sequence_input = _concatenate_context_input(sequence_input, - context_input) + sequence_input = seq_fc.concatenate_context_input( + context_input, sequence_input) cell = rnn_cell_fn(mode) # Ignore output state. 
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py index 1aebed348dcacf8fbe90421bdc7ff25f5b7bcc4a..89506ee6615cd838b0fe651e13eb3e7dd35d2aef 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py @@ -25,12 +25,12 @@ import tempfile import numpy as np import six -from tensorflow.contrib.data.python.ops import readers from tensorflow.contrib.estimator.python.estimator import head as head_lib from tensorflow.contrib.estimator.python.estimator import rnn from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.experimental.ops import readers from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.canned import parsing_utils diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 510f29250899eb6ab4062f772f4c72dbdc70c2dd..e344d7a23b55134612aab430b50cf065bd1095e4 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -154,8 +154,6 @@ tf_py_test( ], tags = [ "no_pip", # b/38283730 - "noasan", # b/116875897 - "nomsan", "notsan", # Flaky: b/30756419 ], ) @@ -179,11 +177,7 @@ tf_py_test( "//tensorflow/python:random_seed", "//tensorflow/python:variables", ], - tags = [ - "noasan", # b/116875897 - "nomsan", - "notsan", # b/62863147 - ], + tags = ["notsan"], # b/62863147 ) py_library( @@ -282,7 +276,6 @@ tf_py_test( "manual", "noasan", # times out b/63678675 "nomsan", - "notsan", # b/116875897 ], ) diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index 
e076631bc16fd379a2ad31af9055a7388d98c7ca..d365ad111760247fc18b730657390f07ba6b865e 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -154,10 +154,10 @@ class GmmAlgorithm(object): def _create_variables(self): """Initializes GMM algorithm.""" init_value = array_ops.constant([], dtype=dtypes.float32) - self._means = variables.Variable(init_value, - name=self.CLUSTERS_VARIABLE, - validate_shape=False) - self._covs = variables.Variable( + self._means = variables.VariableV1(init_value, + name=self.CLUSTERS_VARIABLE, + validate_shape=False) + self._covs = variables.VariableV1( init_value, name=self.CLUSTERS_COVS_VARIABLE, validate_shape=False) # Mixture weights, representing the probability that a randomly # selected unobservable data (in EM terms) was generated by component k. @@ -165,9 +165,9 @@ class GmmAlgorithm(object): array_ops.tile([1.0 / self._num_classes], [self._num_classes]), name=self.CLUSTERS_WEIGHT, validate_shape=False) - self._cluster_centers_initialized = variables.Variable(False, - dtype=dtypes.bool, - name='initialized') + self._cluster_centers_initialized = variables.VariableV1(False, + dtype=dtypes.bool, + name='initialized') def _initialize_variables(self, data, initial_means=None): """Initializes variables. 
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index 9bdbd050152261daff803e6e71abea93406402ed..75d577f42958d97ccb2632798e86ae059c399cb4 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -420,13 +420,13 @@ class WALSMatrixFactorizationUnsupportedTest(test.TestCase): class SweepHookTest(test.TestCase): def test_sweeps(self): - is_row_sweep_var = variables.Variable(True) - is_sweep_done_var = variables.Variable(False) - init_done = variables.Variable(False) - row_prep_done = variables.Variable(False) - col_prep_done = variables.Variable(False) - row_train_done = variables.Variable(False) - col_train_done = variables.Variable(False) + is_row_sweep_var = variables.VariableV1(True) + is_sweep_done_var = variables.VariableV1(False) + init_done = variables.VariableV1(False) + row_prep_done = variables.VariableV1(False) + col_prep_done = variables.VariableV1(False) + row_train_done = variables.VariableV1(False) + col_train_done = variables.VariableV1(False) init_op = state_ops.assign(init_done, True) row_prep_op = state_ops.assign(row_prep_done, True) @@ -486,7 +486,7 @@ class StopAtSweepHookTest(test.TestCase): def test_stop(self): hook = wals_lib._StopAtSweepHook(last_sweep=10) - completed_sweeps = variables.Variable( + completed_sweeps = variables.VariableV1( 8, name=wals_lib.WALSMatrixFactorization.COMPLETED_SWEEPS) train_op = state_ops.assign_add(completed_sweeps, 1) hook.begin() diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index aab7d0c9e8874269bfa5f33193b0dc0ba4bbc9cd..a926ffd5982116a21dc7a0fd1ff957d4ecc6bf94 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -27,6 +27,7 @@ py_library( "//tensorflow/python:check_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + 
"//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_ops", "//tensorflow/python:tensor_shape", @@ -46,9 +47,29 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python/feature_column", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "sequence_feature_column_integration_test", + srcs = ["python/feature_column/sequence_feature_column_integration_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":sequence_feature_column", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python/feature_column", + "//tensorflow/python/keras:layers", ], ) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index 05bcdac2caa77062f9a8a44a948d2897b439ea1f..dd6da35ed009c07ad3819e7860a283c7837c1f83 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -33,7 +33,6 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variable_scope # pylint: disable=protected-access -# TODO(b/73827486): Support SequenceExample. 
def sequence_input_layer( @@ -110,6 +109,7 @@ def sequence_input_layer( output_tensors = [] sequence_lengths = [] ordered_columns = [] + for column in sorted(feature_columns, key=lambda x: x.name): ordered_columns.append(column) with variable_scope.variable_scope( @@ -121,17 +121,67 @@ def sequence_input_layer( # Flattens the final dimension to produce a 3D Tensor. num_elements = column._variable_shape.num_elements() shape = array_ops.shape(dense_tensor) + target_shape = [shape[0], shape[1], num_elements] output_tensors.append( - array_ops.reshape( - dense_tensor, - shape=array_ops.concat([shape[:2], [num_elements]], axis=0))) + array_ops.reshape(dense_tensor, shape=target_shape)) sequence_lengths.append(sequence_length) + fc._verify_static_batch_size_equality(output_tensors, ordered_columns) fc._verify_static_batch_size_equality(sequence_lengths, ordered_columns) sequence_length = _assert_all_equal_and_return(sequence_lengths) + return array_ops.concat(output_tensors, -1), sequence_length +def concatenate_context_input(context_input, sequence_input): + """Replicates `context_input` across all timesteps of `sequence_input`. + + Expands dimension 1 of `context_input` then tiles it `sequence_length` times. + This value is appended to `sequence_input` on dimension 2 and the result is + returned. + + Args: + context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`. + sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size, + padded_length, d0]`. + + Returns: + A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, + d0 + d1]`. + + Raises: + ValueError: If `sequence_input` does not have rank 3 or `context_input` does + not have rank 2. 
+ """ + seq_rank_check = check_ops.assert_rank( + sequence_input, + 3, + message='sequence_input must have rank 3', + data=[array_ops.shape(sequence_input)]) + seq_type_check = check_ops.assert_type( + sequence_input, + dtypes.float32, + message='sequence_input must have dtype float32; got {}.'.format( + sequence_input.dtype)) + ctx_rank_check = check_ops.assert_rank( + context_input, + 2, + message='context_input must have rank 2', + data=[array_ops.shape(context_input)]) + ctx_type_check = check_ops.assert_type( + context_input, + dtypes.float32, + message='context_input must have dtype float32; got {}.'.format( + context_input.dtype)) + with ops.control_dependencies( + [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]): + padded_length = array_ops.shape(sequence_input)[1] + tiled_context_input = array_ops.tile( + array_ops.expand_dims(context_input, 1), + array_ops.concat([[1], [padded_length], [1]], 0)) + return array_ops.concat([sequence_input, tiled_context_input], 2) + + def sequence_categorical_column_with_identity( key, num_buckets, default_value=None): """Returns a feature column that represents sequences of integers. @@ -453,9 +503,17 @@ class _SequenceNumericColumn( [array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape], axis=0) dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape) - sequence_length = fc._sequence_length_from_sparse_tensor( - sp_tensor, num_elements=self._variable_shape.num_elements()) + + # Get the number of timesteps per example + # For the 2D case, the raw values are grouped according to num_elements; + # for the 3D case, the grouping happens in the third dimension, and + # sequence length is not affected. 
+ num_elements = (self._variable_shape.num_elements() + if sp_tensor.shape.ndims == 2 else 1) + seq_length = fc._sequence_length_from_sparse_tensor( + sp_tensor, num_elements=num_elements) + return fc._SequenceDenseColumn.TensorSequenceLengthPair( - dense_tensor=dense_tensor, sequence_length=sequence_length) + dense_tensor=dense_tensor, sequence_length=seq_length) # pylint: enable=protected-access diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d8ca363627eace15e039679545366648df174c33 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py @@ -0,0 +1,280 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Integration test for sequence feature columns with SequenceExamples.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import string +import tempfile + +from google.protobuf import text_format + +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.keras.layers import recurrent +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class SequenceFeatureColumnIntegrationTest(test.TestCase): + + def _make_sequence_example(self): + example = example_pb2.SequenceExample() + example.context.feature['int_ctx'].int64_list.value.extend([5]) + example.context.feature['float_ctx'].float_list.value.extend([123.6]) + for val in range(0, 10, 2): + feat = feature_pb2.Feature() + feat.int64_list.value.extend([val] * val) + example.feature_lists.feature_list['int_list'].feature.extend([feat]) + for val in range(1, 11, 2): + feat = feature_pb2.Feature() + feat.bytes_list.value.extend([compat.as_bytes(str(val))] * val) + example.feature_lists.feature_list['str_list'].feature.extend([feat]) + + return example + + def _build_feature_columns(self): + col = fc.categorical_column_with_identity( + 'int_ctx', num_buckets=100) + ctx_cols = [ + fc.embedding_column(col, dimension=10), + fc.numeric_column('float_ctx')] + + identity_col = sfc.sequence_categorical_column_with_identity( + 'int_list', num_buckets=10) + bucket_col = sfc.sequence_categorical_column_with_hash_bucket( + 'bytes_list', hash_bucket_size=100) + seq_cols = [ + 
fc.embedding_column(identity_col, dimension=10), + fc.embedding_column(bucket_col, dimension=20)] + + return ctx_cols, seq_cols + + def test_sequence_example_into_input_layer(self): + examples = [_make_sequence_example().SerializeToString()] * 100 + ctx_cols, seq_cols = self._build_feature_columns() + + def _parse_example(example): + ctx, seq = parsing_ops.parse_single_sequence_example( + example, + context_features=fc.make_parse_example_spec(ctx_cols), + sequence_features=fc.make_parse_example_spec(seq_cols)) + ctx.update(seq) + return ctx + + ds = dataset_ops.Dataset.from_tensor_slices(examples) + ds = ds.map(_parse_example) + ds = ds.batch(20) + + # Test on a single batch + features = ds.make_one_shot_iterator().get_next() + + # Tile the context features across the sequence features + seq_layer, _ = sfc.sequence_input_layer(features, seq_cols) + ctx_layer = fc.input_layer(features, ctx_cols) + input_layer = sfc.concatenate_context_input(ctx_layer, seq_layer) + + rnn_layer = recurrent.RNN(recurrent.SimpleRNNCell(10)) + output = rnn_layer(input_layer) + + with self.cached_session() as sess: + sess.run(variables.global_variables_initializer()) + features_r = sess.run(features) + self.assertAllEqual(features_r['int_list'].dense_shape, [20, 3, 6]) + + output_r = sess.run(output) + self.assertAllEqual(output_r.shape, [20, 10]) + + +class SequenceExampleParsingTest(test.TestCase): + + def test_seq_ex_in_sequence_categorical_column_with_identity(self): + self._test_parsed_sequence_example( + 'int_list', sfc.sequence_categorical_column_with_identity, + 10, [3, 6], [2, 4, 6]) + + def test_seq_ex_in_sequence_categorical_column_with_hash_bucket(self): + self._test_parsed_sequence_example( + 'bytes_list', sfc.sequence_categorical_column_with_hash_bucket, + 10, [3, 4], [compat.as_bytes(x) for x in 'acg']) + + def test_seq_ex_in_sequence_categorical_column_with_vocabulary_list(self): + self._test_parsed_sequence_example( + 'bytes_list', 
sfc.sequence_categorical_column_with_vocabulary_list, + list(string.ascii_lowercase), [3, 4], + [compat.as_bytes(x) for x in 'acg']) + + def test_seq_ex_in_sequence_categorical_column_with_vocabulary_file(self): + _, fname = tempfile.mkstemp() + with open(fname, 'w') as f: + f.write(string.ascii_lowercase) + self._test_parsed_sequence_example( + 'bytes_list', sfc.sequence_categorical_column_with_vocabulary_file, + fname, [3, 4], [compat.as_bytes(x) for x in 'acg']) + + def _test_parsed_sequence_example( + self, col_name, col_fn, col_arg, shape, values): + """Helper function to check that each FeatureColumn parses correctly. + + Args: + col_name: string, name to give to the feature column. Should match + the name that the column will parse out of the features dict. + col_fn: function used to create the feature column. For example, + sequence_numeric_column. + col_arg: second arg that the target feature column is expecting. + shape: the expected dense_shape of the feature after parsing into + a SparseTensor. + values: the expected values at index [0, 2, 6] of the feature + after parsing into a SparseTensor. 
+ """ + example = _make_sequence_example() + columns = [ + fc.categorical_column_with_identity('int_ctx', num_buckets=100), + fc.numeric_column('float_ctx'), + col_fn(col_name, col_arg) + ] + context, seq_features = parsing_ops.parse_single_sequence_example( + example.SerializeToString(), + context_features=fc.make_parse_example_spec(columns[:2]), + sequence_features=fc.make_parse_example_spec(columns[2:])) + + with self.cached_session() as sess: + ctx_result, seq_result = sess.run([context, seq_features]) + self.assertEqual(list(seq_result[col_name].dense_shape), shape) + self.assertEqual( + list(seq_result[col_name].values[[0, 2, 6]]), values) + self.assertEqual(list(ctx_result['int_ctx'].dense_shape), [1]) + self.assertEqual(ctx_result['int_ctx'].values[0], 5) + self.assertEqual(list(ctx_result['float_ctx'].shape), [1]) + self.assertAlmostEqual(ctx_result['float_ctx'][0], 123.6, places=1) + + +_SEQ_EX_PROTO = """ +context { + feature { + key: "float_ctx" + value { + float_list { + value: 123.6 + } + } + } + feature { + key: "int_ctx" + value { + int64_list { + value: 5 + } + } + } +} +feature_lists { + feature_list { + key: "bytes_list" + value { + feature { + bytes_list { + value: "a" + } + } + feature { + bytes_list { + value: "b" + value: "c" + } + } + feature { + bytes_list { + value: "d" + value: "e" + value: "f" + value: "g" + } + } + } + } + feature_list { + key: "float_list" + value { + feature { + float_list { + value: 1.0 + } + } + feature { + float_list { + value: 3.0 + value: 3.0 + value: 3.0 + } + } + feature { + float_list { + value: 5.0 + value: 5.0 + value: 5.0 + value: 5.0 + value: 5.0 + } + } + } + } + feature_list { + key: "int_list" + value { + feature { + int64_list { + value: 2 + value: 2 + } + } + feature { + int64_list { + value: 4 + value: 4 + value: 4 + value: 4 + } + } + feature { + int64_list { + value: 6 + value: 6 + value: 6 + value: 6 + value: 6 + value: 6 + } + } + } + } +} +""" + + +def _make_sequence_example(): + example = 
example_pb2.SequenceExample() + return text_format.Parse(_SEQ_EX_PROTO, example) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 45d7b740462ca21139e2e93e34b43668f1e08a94..2163af0b43864c96483df529f07881f2f985a80e 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os +from absl.testing import parameterized import numpy as np from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc @@ -28,28 +29,64 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test from tensorflow.python.training import monitored_session -class SequenceInputLayerTest(test.TestCase): +class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args_a': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'sparse_input_args_b': { + # example 0, ids [1] + # example 1, ids [2, 0] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (1, 2, 0), + 'dense_shape': (2, 2)}, + 'expected_input_layer': [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 17., 18., 19.], 
[3., 4., 11., 12., 13.]],], + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'sparse_input_args_a': { + # feature 0, ids [[2], [0, 1]] + # feature 1, ids [[0, 0], [1]] + 'indices': ( + (0, 0, 0), (0, 1, 0), (0, 1, 1), + (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 0, 0, 1), + 'dense_shape': (2, 2, 2)}, + 'sparse_input_args_b': { + # feature 0, ids [[1, 1], [1]] + # feature 1, ids [[2], [0]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (1, 1, 1, 2, 0), + 'dense_shape': (2, 2, 2)}, + 'expected_input_layer': [ + # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -] + [[5., 6., 14., 15., 16.], [2., 3., 14., 15., 16.]], + # feature 1, [a: 0, 0, b: 2, -], [a: 1, -, b: 0, -] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_embedding_column( + self, sparse_input_args_a, sparse_input_args_b, expected_input_layer, + expected_sequence_length): - def test_embedding_column(self): + sparse_input_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a) + sparse_input_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b) vocabulary_size = 3 - sparse_input_a = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) - sparse_input_b = sparse_tensor.SparseTensorValue( - # example 0, ids [1] - # example 1, ids [2, 0] - indices=((0, 0), (1, 0), (1, 1)), - values=(1, 2, 0), - dense_shape=(2, 2)) - embedding_dimension_a = 2 embedding_values_a = ( (1., 2.), # id 0 @@ -70,14 +107,6 @@ class SequenceInputLayerTest(test.TestCase): return embedding_values return _initializer - expected_input_layer = [ - # example 0, ids_a [2], ids_b [1] - [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], - # example 1, ids_a [0, 1], ids_b [2, 0] - [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]], - ] - expected_sequence_length = [1, 2] - categorical_column_a = 
sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) embedding_column_a = fc.embedding_column( @@ -233,29 +262,56 @@ class SequenceInputLayerTest(test.TestCase): }, feature_columns=shared_embedding_columns) - def test_indicator_column(self): + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args_a': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'sparse_input_args_b': { + # example 0, ids [1] + # example 1, ids [1, 0] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (1, 1, 0), + 'dense_shape': (2, 2)}, + 'expected_input_layer': [ + # example 0, ids_a [2], ids_b [1] + [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [1, 0] + [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]], + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'sparse_input_args_a': { + # feature 0, ids [[2], [0, 1]] + # feature 1, ids [[0, 0], [1]] + 'indices': ( + (0, 0, 0), (0, 1, 0), (0, 1, 1), + (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 0, 0, 1), + 'dense_shape': (2, 2, 2)}, + 'sparse_input_args_b': { + # feature 0, ids [[1, 1], [1]] + # feature 1, ids [[1], [0]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (1, 1, 1, 1, 0), + 'dense_shape': (2, 2, 2)}, + 'expected_input_layer': [ + # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -] + [[0., 0., 1., 0., 2.], [1., 1., 0., 0., 1.]], + # feature 1, [a: 0, 0, b: 1, -], [a: 1, -, b: 0, -] + [[2., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_indicator_column( + self, sparse_input_args_a, sparse_input_args_b, expected_input_layer, + expected_sequence_length): + sparse_input_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a) + sparse_input_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b) + vocabulary_size_a = 3 - sparse_input_a = 
sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) vocabulary_size_b = 2 - sparse_input_b = sparse_tensor.SparseTensorValue( - # example 0, ids [1] - # example 1, ids [1, 0] - indices=((0, 0), (1, 0), (1, 1)), - values=(1, 1, 0), - dense_shape=(2, 2)) - - expected_input_layer = [ - # example 0, ids_a [2], ids_b [1] - [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]], - # example 1, ids_a [0, 1], ids_b [1, 0] - [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]], - ] - expected_sequence_length = [1, 2] categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size_a) @@ -298,18 +354,34 @@ class SequenceInputLayerTest(test.TestCase): features={'aaa': sparse_input}, feature_columns=[indicator_column_a]) - def test_numeric_column(self): - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0.], [1]] - # example 1, [[10.]] - indices=((0, 0), (0, 1), (1, 0)), - values=(0., 1., 10.), - dense_shape=(2, 2)) - expected_input_layer = [ - [[0.], [1.]], - [[10.], [0.]], - ] - expected_sequence_length = [2, 1] + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [0., 1] + # example 1, [10.] 
+ 'indices': ((0, 0), (0, 1), (1, 0)), + 'values': (0., 1., 10.), + 'dense_shape': (2, 2)}, + 'expected_input_layer': [ + [[0.], [1.]], + [[10.], [0.]]], + 'expected_sequence_length': [2, 1]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # feature 0, ids [[20, 3], [5]] + # feature 1, ids [[3], [8]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (20, 3, 5., 3., 8.), + 'dense_shape': (2, 2, 2)}, + 'expected_input_layer': [ + [[20.], [3.], [5.], [0.]], + [[3.], [0.], [8.], [0.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_numeric_column( + self, sparse_input_args, expected_input_layer, expected_sequence_length): + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + numeric_column = sfc.sequence_numeric_column('aaa') input_layer, sequence_length = sfc.sequence_input_layer( @@ -321,21 +393,40 @@ class SequenceInputLayerTest(test.TestCase): self.assertAllEqual( expected_sequence_length, sequence_length.eval(session=sess)) - def test_numeric_column_multi_dim(self): + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [0., 1., 2., 3., 4., 5., 6., 7.] + # example 1, [10., 11., 12., 13.] + 'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 8)}, + 'expected_input_layer': [ + # The output of numeric_column._get_dense_tensor should be flattened. 
+ [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]]], + 'expected_sequence_length': [2, 1]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]] + # example 1, [[10., 11., 12., 13.], []] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), + (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 3), + (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 2, 4)}, + 'expected_input_layer': [ + # The output of numeric_column._get_dense_tensor should be flattened. + [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]]], + 'expected_sequence_length': [2, 1]}, + ) + def test_numeric_column_multi_dim( + self, sparse_input_args, expected_input_layer, expected_sequence_length): """Tests sequence_input_layer for multi-dimensional numeric_column.""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] - # example 1, [[[10., 11.], [12., 13.]]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), - (1, 0), (1, 1), (1, 2), (1, 3)), - values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), - dense_shape=(2, 8)) - # The output of numeric_column._get_dense_tensor should be flattened. 
- expected_input_layer = [ - [[0., 1., 2., 3.], [4., 5., 6., 7.]], - [[10., 11., 12., 13.], [0., 0., 0., 0.]], - ] - expected_sequence_length = [2, 1] + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) input_layer, sequence_length = sfc.sequence_input_layer( @@ -377,6 +468,138 @@ class SequenceInputLayerTest(test.TestCase): r'\[y \(sequence_input_layer/bbb/sequence_length:0\) = \] \[1 1\]'): sess.run(sequence_length) + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + 'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 8)}, + 'expected_shape': [2, 2, 4]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]] + # example 1, [[10., 11., 12., 13.], []] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), + (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 2), + (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 2, 4)}, + 'expected_shape': [2, 2, 4]}, + ) + def test_static_shape_from_tensors_numeric( + self, sparse_input_args, expected_shape): + """Tests that we return a known static shape when we have one.""" + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + input_layer, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + shape = input_layer.get_shape() + self.assertEqual(shape, expected_shape) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, ids [2] + # example 1, ids [0, 
1] + # example 2, ids [] + # example 3, ids [1] + 'indices': ((0, 0), (1, 0), (1, 1), (3, 0)), + 'values': (2, 0, 1, 1), + 'dense_shape': (4, 2)}, + 'expected_shape': [4, 2, 3]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [0, 2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + 'values': (2, 0, 1, 2, 1, 0, 2), + 'dense_shape': (4, 2, 2)}, + 'expected_shape': [4, 2, 3]} + ) + def test_static_shape_from_tensors_indicator( + self, sparse_input_args, expected_shape): + """Tests that we return a known static shape when we have one.""" + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=3) + indicator_column = fc.indicator_column(categorical_column) + + input_layer, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, feature_columns=[indicator_column]) + shape = input_layer.get_shape() + self.assertEqual(shape, expected_shape) + + +class ConcatenateContextInputTest(test.TestCase, parameterized.TestCase): + """Tests the utility fn concatenate_context_input.""" + + def test_concatenate_context_input(self): + seq_input = ops.convert_to_tensor(np.arange(12).reshape(2, 3, 2)) + context_input = ops.convert_to_tensor(np.arange(10).reshape(2, 5)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + input_layer = sfc.concatenate_context_input(context_input, seq_input) + + expected = np.array([ + [[0, 1, 0, 1, 2, 3, 4], [2, 3, 0, 1, 2, 3, 4], [4, 5, 0, 1, 2, 3, 4]], + [[6, 7, 5, 6, 7, 8, 9], [8, 9, 5, 6, 7, 8, 9], [10, 11, 5, 6, 7, 8, 9]] + ], dtype=np.float32) + with monitored_session.MonitoredSession() as sess: + output = sess.run(input_layer) + self.assertAllEqual(expected, output) + + @parameterized.named_parameters( + 
{'testcase_name': 'rank_lt_3', + 'seq_input_arg': np.arange(100).reshape(10, 10)}, + {'testcase_name': 'rank_gt_3', + 'seq_input_arg': np.arange(100).reshape(5, 5, 2, 2)} + ) + def test_sequence_input_throws_error(self, seq_input_arg): + seq_input = ops.convert_to_tensor(seq_input_arg) + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp(ValueError, 'sequence_input must have rank 3'): + sfc.concatenate_context_input(context_input, seq_input) + + @parameterized.named_parameters( + {'testcase_name': 'rank_lt_2', + 'context_input_arg': np.arange(100)}, + {'testcase_name': 'rank_gt_2', + 'context_input_arg': np.arange(100).reshape(5, 5, 4)} + ) + def test_context_input_throws_error(self, context_input_arg): + context_input = ops.convert_to_tensor(context_input_arg) + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp(ValueError, 'context_input must have rank 2'): + sfc.concatenate_context_input(context_input, seq_input) + + def test_integer_seq_input_throws_error(self): + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp( + TypeError, 'sequence_input must have dtype float32'): + sfc.concatenate_context_input(context_input, seq_input) + + def test_integer_context_input_throws_error(self): + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + with self.assertRaisesRegexp( + TypeError, 
'context_input must have dtype float32'): + sfc.concatenate_context_input(context_input, seq_input) + class InputLayerTest(test.TestCase): """Tests input_layer with sequence feature columns.""" @@ -443,75 +666,83 @@ def _assert_sparse_tensor_indices_shape(test_case, expected, actual): test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) -class SequenceCategoricalColumnWithIdentityTest(test.TestCase): - - def test_get_sparse_tensors(self): - column = sfc.sequence_categorical_column_with_identity( - 'aaa', num_buckets=3) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=(1, 2, 0), - dense_shape=(2, 2)) - expected_sparse_ids = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - values=np.array((1, 2, 0), dtype=np.int64), - dense_shape=(2, 2, 1)) +class SequenceCategoricalColumnWithIdentityTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (1, 2, 0), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + 'values': np.array((1, 2, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': (6, 7, 8), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': (6, 7, 8), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) + column = sfc.sequence_categorical_column_with_identity('aaa', num_buckets=9) id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: _assert_sparse_tensor_value( - self, - 
expected_sparse_ids, - id_weight_pair.id_tensor.eval(session=sess)) - - def test_get_sparse_tensors_inputs3d(self): - """Tests _get_sparse_tensors when the input is already 3D Tensor.""" - column = sfc.sequence_categorical_column_with_identity( - 'aaa', num_buckets=3) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - values=(1, 2, 0), - dense_shape=(2, 2, 1)) - - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - r'Column aaa expected ID tensor of rank 2\.\s*' - r'id_tensor shape:\s*\[2 2 1\]'): - id_weight_pair = column._get_sparse_tensors( - _LazyBuilder({'aaa': inputs})) - with monitored_session.MonitoredSession() as sess: - id_weight_pair.id_tensor.eval(session=sess) - - -class SequenceCategoricalColumnWithHashBucketTest(test.TestCase): - - def test_get_sparse_tensors(self): + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithHashBucketTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': ('omar', 'stringer', 'marlo'), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + # Ignored to avoid hash dependence in test. + 'values': np.array((0, 0, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': ('omar', 'stringer', 'marlo'), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + # Ignored to avoid hash dependence in test. 
+ 'values': np.array((0, 0, 0), dtype=np.int64), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) column = sfc.sequence_categorical_column_with_hash_bucket( 'aaa', hash_bucket_size=10) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=('omar', 'stringer', 'marlo'), - dense_shape=(2, 2)) - - expected_sparse_ids = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - # Ignored to avoid hash dependence in test. - values=np.array((0, 0, 0), dtype=np.int64), - dense_shape=(2, 2, 1)) id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: _assert_sparse_tensor_indices_shape( - self, - expected_sparse_ids, - id_weight_pair.id_tensor.eval(session=sess)) + self, expected, id_weight_pair.id_tensor.eval(session=sess)) -class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase): +class SequenceCategoricalColumnWithVocabularyFileTest( + test.TestCase, parameterized.TestCase): def _write_vocab(self, vocab_strings, file_name): vocab_file = os.path.join(self.get_temp_dir(), file_name) @@ -527,68 +758,152 @@ class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase): 'wire_vocabulary.txt') self._wire_vocabulary_size = 3 - def test_get_sparse_tensors(self): + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': ('marlo', 'skywalker', 'omar'), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + 'values': np.array((2, -1, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': ('omar', 'skywalker', 
'marlo'), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': np.array((0, -1, 2), dtype=np.int64), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) column = sfc.sequence_categorical_column_with_vocabulary_file( key='aaa', vocabulary_file=self._wire_vocabulary_file_name, vocabulary_size=self._wire_vocabulary_size) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=('marlo', 'skywalker', 'omar'), - dense_shape=(2, 2)) - expected_sparse_ids = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - values=np.array((2, -1, 0), dtype=np.int64), - dense_shape=(2, 2, 1)) id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: _assert_sparse_tensor_value( - self, - expected_sparse_ids, - id_weight_pair.id_tensor.eval(session=sess)) - + self, expected, id_weight_pair.id_tensor.eval(session=sess)) -class SequenceCategoricalColumnWithVocabularyListTest(test.TestCase): + def test_get_sparse_tensors_dynamic_zero_length(self): + """Tests _get_sparse_tensors with a dynamic sequence length.""" + inputs = sparse_tensor.SparseTensorValue( + indices=np.zeros((0, 2)), values=[], dense_shape=(2, 0)) + expected = sparse_tensor.SparseTensorValue( + indices=np.zeros((0, 3)), + values=np.array((), dtype=np.int64), + dense_shape=(2, 0, 1)) + column = sfc.sequence_categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + input_placeholder_shape = list(inputs.dense_shape) + # Make second dimension (sequence length) dynamic. 
+ input_placeholder_shape[1] = None + input_placeholder = array_ops.sparse_placeholder( + dtypes.string, shape=input_placeholder_shape) + id_weight_pair = column._get_sparse_tensors( + _LazyBuilder({'aaa': input_placeholder})) - def test_get_sparse_tensors(self): + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + result = id_weight_pair.id_tensor.eval( + session=sess, feed_dict={input_placeholder: inputs}) + _assert_sparse_tensor_value( + self, expected, result) + + +class SequenceCategoricalColumnWithVocabularyListTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': ('marlo', 'skywalker', 'omar'), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + 'values': np.array((2, -1, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': ('omar', 'skywalker', 'marlo'), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': np.array((0, -1, 2), dtype=np.int64), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) column = sfc.sequence_categorical_column_with_vocabulary_list( key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=('marlo', 'skywalker', 'omar'), - dense_shape=(2, 2)) - expected_sparse_ids = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - values=np.array((2, -1, 0), dtype=np.int64), - dense_shape=(2, 2, 1)) id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) 
self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: _assert_sparse_tensor_value( - self, - expected_sparse_ids, - id_weight_pair.id_tensor.eval(session=sess)) - - -class SequenceEmbeddingColumnTest(test.TestCase): - - def test_get_sequence_dense_tensor(self): + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceEmbeddingColumnTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + 'indices': ((0, 0), (1, 0), (1, 1), (3, 0)), + 'values': (2, 0, 1, 1), + 'dense_shape': (4, 2)}, + 'expected': [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]]]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [0, 2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + 'values': (2, 0, 1, 2, 1, 0, 2), + 'dense_shape': (4, 2, 2)}, + 'expected': [ + # example 0, ids [[2]] + [[7., 11.], [0., 0.]], + # example 1, ids [[0, 1], [2]] + [[2, 3.5], [7., 11.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [[1], [0, 2]] + [[3., 5.], [4., 6.5]]]} + ) + def test_get_sequence_dense_tensor(self, inputs_args, expected): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - # example 2, ids [] - # example 3, ids [1] - indices=((0, 0), (1, 0), (1, 1), (3, 0)), - values=(2, 0, 1, 1), - dense_shape=(4, 2)) - embedding_dimension = 2 embedding_values = ( (1., 2.), # id 0 @@ -601,17 +916,6 @@ class SequenceEmbeddingColumnTest(test.TestCase): 
self.assertIsNone(partition_info) return embedding_values - expected_lookups = [ - # example 0, ids [2] - [[7., 11.], [0., 0.]], - # example 1, ids [0, 1] - [[1., 2.], [3., 5.]], - # example 2, ids [] - [[0., 0.], [0., 0.]], - # example 3, ids [1] - [[3., 5.], [0., 0.]], - ] - categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) embedding_column = fc.embedding_column( @@ -619,24 +923,36 @@ class SequenceEmbeddingColumnTest(test.TestCase): initializer=_initializer) embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertItemsEqual( ('embedding_weights:0',), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) - self.assertAllEqual(expected_lookups, embedding_lookup.eval(session=sess)) - - def test_sequence_length(self): + self.assertAllEqual(expected, embedding_lookup.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 2), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2]} + ) + def test_sequence_length(self, inputs_args, expected_sequence_length): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) - expected_sequence_length = [1, 
2] categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) @@ -644,7 +960,7 @@ class SequenceEmbeddingColumnTest(test.TestCase): categorical_column, dimension=2) _, sequence_length = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) @@ -855,56 +1171,89 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): expected_sequence_length_b, sequence_length_b.eval(session=sess)) -class SequenceIndicatorColumnTest(test.TestCase): - - def test_get_sequence_dense_tensor(self): +class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + 'indices': ((0, 0), (1, 0), (1, 1), (3, 0)), + 'values': (2, 0, 1, 1), + 'dense_shape': (4, 2)}, + 'expected': [ + # example 0, ids [2] + [[0., 0., 1.], [0., 0., 0.]], + # example 1, ids [0, 1] + [[1., 0., 0.], [0., 1., 0.]], + # example 2, ids [] + [[0., 0., 0.], [0., 0., 0.]], + # example 3, ids [1] + [[0., 1., 0.], [0., 0., 0.]]]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [2, 2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + 'values': (2, 0, 1, 2, 1, 2, 2), + 'dense_shape': (4, 2, 2)}, + 'expected': [ + # example 0, ids [[2]] + [[0., 0., 1.], [0., 0., 0.]], + # example 1, ids [[0, 1], [2]] + [[1., 1., 0.], [0., 0., 1.]], + # example 2, ids [] + [[0., 0., 0.], [0., 0., 0.]], + # example 3, ids [[1], [2, 2]] + [[0., 1., 0.], [0., 0., 2.]]]} + ) + def test_get_sequence_dense_tensor(self, inputs_args, expected): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) vocabulary_size = 
3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - # example 2, ids [] - # example 3, ids [1] - indices=((0, 0), (1, 0), (1, 1), (3, 0)), - values=(2, 0, 1, 1), - dense_shape=(4, 2)) - - expected_lookups = [ - # example 0, ids [2] - [[0., 0., 1.], [0., 0., 0.]], - # example 1, ids [0, 1] - [[1., 0., 0.], [0., 1., 0.]], - # example 2, ids [] - [[0., 0., 0.], [0., 0., 0.]], - # example 3, ids [1] - [[0., 1., 0.], [0., 0., 0.]], - ] categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) indicator_column = fc.indicator_column(categorical_column) indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: - self.assertAllEqual(expected_lookups, indicator_tensor.eval(session=sess)) - - def test_sequence_length(self): + self.assertAllEqual(expected, indicator_tensor.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 2), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2]} + ) + def test_sequence_length(self, inputs_args, expected_sequence_length): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) - expected_sequence_length = [1, 2] categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', 
num_buckets=vocabulary_size) indicator_column = fc.indicator_column(categorical_column) _, sequence_length = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) @@ -938,7 +1287,7 @@ class SequenceIndicatorColumnTest(test.TestCase): expected_sequence_length, sequence_length.eval(session=sess)) -class SequenceNumericColumnTest(test.TestCase): +class SequenceNumericColumnTest(test.TestCase, parameterized.TestCase): def test_defaults(self): a = sfc.sequence_numeric_column('aaa') @@ -971,25 +1320,37 @@ class SequenceNumericColumnTest(test.TestCase): with self.assertRaisesRegexp(TypeError, 'must be a callable'): sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable') - def test_get_sequence_dense_tensor(self): - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0.], [1]] - # example 1, [[10.]] - indices=((0, 0), (0, 1), (1, 0)), - values=(0., 1., 10.), - dense_shape=(2, 2)) - expected_dense_tensor = [ - [[0.], [1.]], - [[10.], [0.]], - ] + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, values [0., 1] + # example 1, [10.] 
+ 'indices': ((0, 0), (0, 1), (1, 0)), + 'values': (0., 1., 10.), + 'dense_shape': (2, 2)}, + 'expected': [ + [[0.], [1.]], + [[10.], [0.]]]}, + {'testcase_name': '3D', + 'inputs_args': { + # feature 0, ids [[20, 3], [5]] + # feature 1, ids [[3], [8]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (20, 3, 5., 3., 8.), + 'dense_shape': (2, 2, 2)}, + 'expected': [ + [[20.], [3.], [5.], [0.]], + [[3.], [0.], [8.], [0.]]]}, + ) + def test_get_sequence_dense_tensor(self, inputs_args, expected): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) numeric_column = sfc.sequence_numeric_column('aaa') dense_tensor, _ = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_dense_tensor, dense_tensor.eval(session=sess)) + self.assertAllEqual(expected, dense_tensor.eval(session=sess)) def test_get_sequence_dense_tensor_with_normalizer_fn(self): @@ -1026,41 +1387,35 @@ class SequenceNumericColumnTest(test.TestCase): self.assertAllEqual( expected_dense_tensor, dense_tensor.eval(session=sess)) - def test_get_sequence_dense_tensor_with_shape(self): - """Tests get_sequence_dense_tensor with shape !=(1,).""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0., 1., 2.], [3., 4., 5.]] - # example 1, [[10., 11., 12.]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), - (1, 0), (1, 1), (1, 2)), - values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), - dense_shape=(2, 6)) - expected_dense_tensor = [ - [[0., 1., 2.], [3., 4., 5.]], - [[10., 11., 12.], [0., 0., 0.]], - ] - numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) - - dense_tensor, _ = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_dense_tensor, dense_tensor.eval(session=sess)) - - def 
test_get_dense_tensor_multi_dim(self): + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + 'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 8)}, + 'expected_dense_tensor': [ + [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], + [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]]]}, + {'testcase_name': '3D', + 'sparse_input_args': { + 'indices': ((0, 0, 0), (0, 0, 2), (0, 0, 4), (0, 0, 6), + (0, 1, 0), (0, 1, 2), (0, 1, 4), (0, 1, 6), + (1, 0, 0), (1, 0, 2), (1, 0, 4), (1, 0, 6)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 2, 8)}, + 'expected_dense_tensor': [ + [[[0., 0.], [1., 0.]], [[2., 0.], [3., 0.]], + [[4., 0.], [5., 0.]], [[6., 0.], [7., 0.]]], + [[[10., 0.], [11., 0.]], [[12., 0.], [13., 0.]], + [[0., 0.], [0., 0.]], [[0., 0.], [0., 0.]]]]}, + ) + def test_get_dense_tensor_multi_dim( + self, sparse_input_args, expected_dense_tensor): """Tests get_sequence_dense_tensor for multi-dim numeric_column.""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] - # example 1, [[[10., 11.], [12., 13.]]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), - (1, 0), (1, 1), (1, 2), (1, 3)), - values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), - dense_shape=(2, 8)) - expected_dense_tensor = [ - [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], - [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]], - ] + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) dense_tensor, _ = numeric_column._get_sequence_dense_tensor( @@ -1070,43 +1425,56 @@ class 
SequenceNumericColumnTest(test.TestCase): self.assertAllEqual( expected_dense_tensor, dense_tensor.eval(session=sess)) - def test_sequence_length(self): - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0., 1., 2.], [3., 4., 5.]] - # example 1, [[10., 11., 12.]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), - (1, 0), (1, 1), (1, 2)), - values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), - dense_shape=(2, 6)) - expected_sequence_length = [2, 1] - numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2., 0., 1.), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 2], + 'shape': (1,)}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2., 0., 1., 2.), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2], + 'shape': (1,)}, + {'testcase_name': '2D_with_shape', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2., 0., 1.), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 1], + 'shape': (2,)}, + {'testcase_name': '3D_with_shape', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2., 0., 1., 2.), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2], + 'shape': (2,)}, + ) + def test_sequence_length(self, inputs_args, expected_sequence_length, shape): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + numeric_column = sfc.sequence_numeric_column('aaa', shape=shape) _, sequence_length = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with 
monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) self.assertAllEqual(expected_sequence_length, sequence_length) self.assertEqual(np.int64, sequence_length.dtype) - def test_sequence_length_with_shape(self): - """Tests _sequence_length with shape !=(1,).""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0.], [1]] - # example 1, [[10.]] - indices=((0, 0), (0, 1), (1, 0)), - values=(0., 1., 10.), - dense_shape=(2, 2)) - expected_sequence_length = [2, 1] - numeric_column = sfc.sequence_numeric_column('aaa') - - _, sequence_length = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) - def test_sequence_length_with_empty_rows(self): """Tests _sequence_length when some examples do not have ids.""" sparse_input = sparse_tensor.SparseTensorValue( diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index 490da9b33b6393ca4336573e9592160a1eaf5e01..57a5bfbf43c915775c6b0ef05baac19581213a09 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -145,6 +145,7 @@ cuda_py_test( "//tensorflow/python:client_testlib", ], tags = [ + "manual", # TODO(b/117128481): re-enable after fixing OSS build "no_pip", "requires-gpu-sm70", ], @@ -169,6 +170,7 @@ cuda_py_test( ], main = "python/ops/fused_conv2d_bias_activation_benchmark.py", tags = [ + "manual", # TODO(b/117128481): re-enable after fixing OSS build "requires-gpu-sm70", ], ) diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md index 4ead66ca13e74bacc0e4679a8d5c4e0f23d04b69..9ab86329eaf0e6fd426aef1f552f4e27c2ad65de 100644 --- a/tensorflow/contrib/gan/README.md +++ b/tensorflow/contrib/gan/README.md @@ -34,28 +34,32 @@ mix TFGAN, native TF, and other custom frameworks TFGAN is composed of several 
parts which were design to exist independently. These include the following main pieces (explained in detail below). -* [core](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py): -provides the main infrastructure needed to train a GAN. Training occurs in four phases, and each phase -can be completed by custom-code or by using a TFGAN library call. - -* [features](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/): -Many common GAN operations and normalization techniques are implemented for you -to use, such as instance normalization and conditioning. - -* [losses](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/): -Easily experiment with already-implemented and well-tested losses and penalties, -such as the Wasserstein loss, gradient penalty, mutual information penalty, etc - -* [evaluation](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/): -Use `Inception Score` or `Frechet Distance` with a pretrained Inception -network to evaluate your unconditional generative model. You can also use -your own pretrained classifier for more specific performance numbers, or use -other methods for evaluating conditional generative models. - -* [examples](https://github.com/tensorflow/models/tree/master/research/gan/) and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): -See examples of how to use TFGAN to make GAN training easier, or use the more complicated examples to jumpstart your -own project. These include unconditional and conditional GANs, InfoGANs, -adversarial losses on existing networks, and image-to-image translation. +* [core](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py): + provides the main infrastructure needed to train a GAN. Training occurs in + four phases, and each phase can be completed by custom-code or by using a + TFGAN library call. 
+ +* [features](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/): + Many common GAN operations and normalization techniques are implemented for + you to use, such as instance normalization and conditioning. + +* [losses](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/): + Easily experiment with already-implemented and well-tested losses and + penalties, such as the Wasserstein loss, gradient penalty, mutual + information penalty, etc + +* [evaluation](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/): + Use `Inception Score`, `Frechet Distance`, or `Kernel Distance` with a + pretrained Inception network to evaluate your unconditional generative + model. You can also use your own pretrained classifier for more specific + performance numbers, or use other methods for evaluating conditional + generative models. + +* [examples](https://github.com/tensorflow/models/tree/master/research/gan/) + and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TFGAN to make + GAN training easier, or use the more complicated examples to jumpstart your + own project. These include unconditional and conditional GANs, InfoGANs, + adversarial losses on existing networks, and image-to-image translation. 
## Training a GAN model diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 7243f150ce540cc96d1960511bc1500b7f917791..219cc199d79eca8c263859ae46bbb1ce0b4442b3 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -112,7 +112,8 @@ class GANEstimator(estimator.Estimator): get_eval_metric_ops_fn=None, add_summaries=None, use_loss_summaries=True, - config=None): + config=None, + warm_start_from=None): """Initializes a GANEstimator instance. Args: @@ -151,6 +152,8 @@ class GANEstimator(estimator.Estimator): use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. config: `RunConfig` object to configure the runtime settings. + warm_start_from: A filepath to a checkpoint or saved model, or a + WarmStartSettings object to configure initialization. Raises: ValueError: If loss functions aren't callable. 
@@ -187,7 +190,8 @@ class GANEstimator(estimator.Estimator): get_hooks_fn, use_loss_summaries) super(GANEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config) + model_fn=_model_fn, model_dir=model_dir, config=config, + warm_start_from=warm_start_from) def _get_gan_model( diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 83f8dd641fa9a641533161373c29c5d2f81746a1..cfc867f0831986ef517f14fee0ed9d4773bb5cb6 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -33,9 +33,11 @@ from tensorflow.contrib.learn.python.learn.learn_io import graph_io from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.estimator import WarmStartSettings from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework.errors_impl import NotFoundError from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib @@ -317,5 +319,71 @@ class GANEstimatorIntegrationTest(test.TestCase): prediction_size=[batch_size, input_dim]) +class GANEstimatorWarmStartTest(test.TestCase): + + def setUp(self): + self._model_dir = self.get_temp_dir() + self.new_variable_name = 'new_var' + self.new_variable_value = [1, 2, 3] + + def tearDown(self): + writer_cache.FileWriterCache.clear() + + def _test_warm_start(self, warm_start_from=None): + """Tests whether WarmStartSettings work as intended.""" + def generator_with_new_variable(noise_dict, mode): + variable_scope.get_variable(name=self.new_variable_name, + initializer=self.new_variable_value, 
+ trainable=True) + return generator_fn(noise_dict, mode) + + def train_input_fn(): + data = np.zeros([3, 4]) + return {'x': data}, data + + est = estimator.GANEstimator( + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + generator_optimizer=training.GradientDescentOptimizer(1.0), + discriminator_optimizer=training.GradientDescentOptimizer(1.0), + model_dir=self._model_dir) + + est.train(train_input_fn, steps=1) + + est_warm = estimator.GANEstimator( + generator_fn=generator_with_new_variable, + discriminator_fn=discriminator_fn, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + generator_optimizer=training.GradientDescentOptimizer(1.0), + discriminator_optimizer=training.GradientDescentOptimizer(1.0), + model_dir=None if warm_start_from else self._model_dir, + warm_start_from=warm_start_from) + + est_warm.train(train_input_fn, steps=1) + + return est_warm + + def test_warm_start_error(self): + """Test if exception when reloading different estimators.""" + with self.assertRaises(NotFoundError): + self._test_warm_start() + + def test_warm_start_success(self): + """Test if GANEstimator allows explicit warm start variable assignment.""" + # Regex matches all variable names in ckpt except for new_var. 
+ var_regex = '^(?!.*%s.*)' % self.new_variable_name + warmstart = WarmStartSettings(ckpt_to_initialize_from=self._model_dir, + vars_to_warm_start=var_regex) + est_warm = self._test_warm_start(warm_start_from=warmstart) + full_variable_name = 'Generator/%s' % self.new_variable_name + self.assertIn(full_variable_name, est_warm.get_variable_names()) + equal_vals = np.array_equal(est_warm.get_variable_value(full_variable_name), + self.new_variable_value) + self.assertTrue(equal_vals) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index d914f549457a1e893ed43a3b8bc1ae5be7bb4303..a71ee53311c1c057a5b41be0331bf56ce1a82f74 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -14,8 +14,8 @@ # ============================================================================== """Model evaluation tools for TFGAN. -These methods come from https://arxiv.org/abs/1606.03498 and -https://arxiv.org/abs/1706.08500. +These methods come from https://arxiv.org/abs/1606.03498, +https://arxiv.org/abs/1706.08500, and https://arxiv.org/abs/1801.01401. 
NOTE: This implementation uses the same weights as in https://github.com/openai/improved-gan/blob/master/inception_score/model.py, @@ -40,6 +40,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import image_ops from tensorflow.python.ops import linalg_ops @@ -64,6 +65,12 @@ __all__ = [ 'frechet_classifier_distance_from_activations', 'mean_only_frechet_classifier_distance_from_activations', 'diagonal_only_frechet_classifier_distance_from_activations', + 'kernel_inception_distance', + 'kernel_inception_distance_and_std', + 'kernel_classifier_distance', + 'kernel_classifier_distance_and_std', + 'kernel_classifier_distance_from_activations', + 'kernel_classifier_distance_and_std_from_activations', 'INCEPTION_DEFAULT_IMAGE_SIZE', ] @@ -734,3 +741,373 @@ frechet_inception_distance = functools.partial( frechet_classifier_distance, classifier_fn=functools.partial( run_inception, output_tensor=INCEPTION_FINAL_POOL)) + + +def kernel_classifier_distance(real_images, + generated_images, + classifier_fn, + num_classifier_batches=1, + max_block_size=1024, + dtype=None): + """Kernel "classifier" distance for evaluating a generative model. + + This is based on the Kernel Inception distance, but for an arbitrary + embedding. + + This technique is described in detail in https://arxiv.org/abs/1801.01401. + Given two distributions P and Q of activations, this function calculates + + E_{X, X' ~ P}[k(X, X')] + E_{Y, Y' ~ Q}[k(Y, Y')] + - 2 E_{X ~ P, Y ~ Q}[k(X, Y)] + + where k is the polynomial kernel + + k(x, y) = ( x^T y / dimension + 1 )^3. + + This captures how different the distributions of real and generated images' + visual features are. 
Like the Frechet distance (and unlike the Inception + score), this is a true distance and incorporates information about the + target images. Unlike the Frechet score, this function computes an + *unbiased* and asymptotically normal estimator, which makes comparing + estimates across models much more intuitive. + + The estimator used takes time quadratic in max_block_size. Larger values of + max_block_size will decrease the variance of the estimator but increase the + computational cost. This differs slightly from the estimator used by the + original paper; it is the block estimator of https://arxiv.org/abs/1307.1954. + + NOTE: the blocking code assumes that real_activations and + generated_activations are both in random order. If either is sorted in a + meaningful order, the estimator will behave poorly. + + NOTE: This function consumes images, computes their activations, and then + computes the classifier score. If you would like to precompute many + activations for real and generated images for large batches, or to compute + multiple scores based on the same images, please use + kernel_clasifier_distance_from_activations(), which this method also uses. + + Args: + real_images: Real images to use to compute Kernel Inception distance. + generated_images: Generated images to use to compute Kernel Inception + distance. + classifier_fn: A function that takes images and produces activations based + on a classifier. + num_classifier_batches: Number of batches to split images in to in order to + efficiently run them through the classifier network. + max_estimator_block_size: integer, default 1024. The distance estimator + splits samples into blocks for computational efficiency. Larger values are + more computationally expensive but decrease the variance of the distance + estimate. + dtype: if not None, coerce activations to this dtype before computations. + + Returns: + The Kernel Inception Distance. 
A floating-point scalar of the same type + as the output of the activations. + """ + return kernel_classifier_distance_and_std( + real_images, + generated_images, + classifier_fn, + num_classifier_batches=num_classifier_batches, + max_block_size=max_block_size, + dtype=dtype)[0] + + +kernel_inception_distance = functools.partial( + kernel_classifier_distance, + classifier_fn=functools.partial( + run_inception, output_tensor=INCEPTION_FINAL_POOL)) + + +def kernel_classifier_distance_and_std(real_images, + generated_images, + classifier_fn, + num_classifier_batches=1, + max_block_size=1024, + dtype=None): + """Kernel "classifier" distance for evaluating a generative model. + + This is based on the Kernel Inception distance, but for an arbitrary + embedding. Also returns an estimate of the standard error of the distance + estimator. + + This technique is described in detail in https://arxiv.org/abs/1801.01401. + Given two distributions P and Q of activations, this function calculates + + E_{X, X' ~ P}[k(X, X')] + E_{Y, Y' ~ Q}[k(Y, Y')] + - 2 E_{X ~ P, Y ~ Q}[k(X, Y)] + + where k is the polynomial kernel + + k(x, y) = ( x^T y / dimension + 1 )^3. + + This captures how different the distributions of real and generated images' + visual features are. Like the Frechet distance (and unlike the Inception + score), this is a true distance and incorporates information about the + target images. Unlike the Frechet score, this function computes an + *unbiased* and asymptotically normal estimator, which makes comparing + estimates across models much more intuitive. + + The estimator used takes time quadratic in max_block_size. Larger values of + max_block_size will decrease the variance of the estimator but increase the + computational cost. This differs slightly from the estimator used by the + original paper; it is the block estimator of https://arxiv.org/abs/1307.1954. 
+ + NOTE: the blocking code assumes that real_activations and + generated_activations are both in random order. If either is sorted in a + meaningful order, the estimator will behave poorly. + + NOTE: This function consumes images, computes their activations, and then + computes the classifier score. If you would like to precompute many + activations for real and generated images for large batches, or to compute + multiple scores based on the same images, please use + kernel_clasifier_distance_from_activations(), which this method also uses. + + Args: + real_images: Real images to use to compute Kernel Inception distance. + generated_images: Generated images to use to compute Kernel Inception + distance. + classifier_fn: A function that takes images and produces activations based + on a classifier. + num_classifier_batches: Number of batches to split images in to in order to + efficiently run them through the classifier network. + max_estimator_block_size: integer, default 1024. The distance estimator + splits samples into blocks for computational efficiency. Larger values are + more computationally expensive but decrease the variance of the distance + estimate. Having a smaller block size also gives a better estimate of the + standard error. + dtype: if not None, coerce activations to this dtype before computations. + + Returns: + The Kernel Inception Distance. A floating-point scalar of the same type + as the output of the activations. + An estimate of the standard error of the distance estimator (a scalar of + the same type). + """ + real_images_list = array_ops.split( + real_images, num_or_size_splits=num_classifier_batches) + generated_images_list = array_ops.split( + generated_images, num_or_size_splits=num_classifier_batches) + + real_imgs = array_ops.stack(real_images_list) + generated_imgs = array_ops.stack(generated_images_list) + + # Compute the activations using the memory-efficient `map_fn`. 
+ def compute_activations(elems): + return functional_ops.map_fn( + fn=classifier_fn, + elems=elems, + parallel_iterations=1, + back_prop=False, + swap_memory=True, + name='RunClassifier') + + real_a = compute_activations(real_imgs) + gen_a = compute_activations(generated_imgs) + + # Ensure the activations have the right shapes. + real_a = array_ops.concat(array_ops.unstack(real_a), 0) + gen_a = array_ops.concat(array_ops.unstack(gen_a), 0) + + return kernel_classifier_distance_and_std_from_activations( + real_a, gen_a, max_block_size=max_block_size) + + +kernel_inception_distance_and_std = functools.partial( + kernel_classifier_distance_and_std, + classifier_fn=functools.partial( + run_inception, output_tensor=INCEPTION_FINAL_POOL)) + + +def kernel_classifier_distance_from_activations(real_activations, + generated_activations, + max_block_size=1024, + dtype=None): + """Kernel "classifier" distance for evaluating a generative model. + + This methods computes the kernel classifier distance from activations of + real images and generated images. This can be used independently of the + kernel_classifier_distance() method, especially in the case of using large + batches during evaluation where we would like to precompute all of the + activations before computing the classifier distance, or if we want to + compute multiple metrics based on the same images. + + This technique is described in detail in https://arxiv.org/abs/1801.01401. + Given two distributions P and Q of activations, this function calculates + + E_{X, X' ~ P}[k(X, X')] + E_{Y, Y' ~ Q}[k(Y, Y')] + - 2 E_{X ~ P, Y ~ Q}[k(X, Y)] + + where k is the polynomial kernel + + k(x, y) = ( x^T y / dimension + 1 )^3. + + This captures how different the distributions of real and generated images' + visual features are. Like the Frechet distance (and unlike the Inception + score), this is a true distance and incorporates information about the + target images. 
Unlike the Frechet score, this function computes an + *unbiased* and asymptotically normal estimator, which makes comparing + estimates across models much more intuitive. + + The estimator used takes time quadratic in max_block_size. Larger values of + max_block_size will decrease the variance of the estimator but increase the + computational cost. This differs slightly from the estimator used by the + original paper; it is the block estimator of https://arxiv.org/abs/1307.1954. + + NOTE: the blocking code assumes that real_activations and + generated_activations are both in random order. If either is sorted in a + meaningful order, the estimator will behave poorly. + + Args: + real_activations: 2D Tensor containing activations of real data. Shape is + [batch_size, activation_size]. + generated_activations: 2D Tensor containing activations of generated data. + Shape is [batch_size, activation_size]. + max_block_size: integer, default 1024. The distance estimator splits samples + into blocks for computational efficiency. Larger values are more + computationally expensive but decrease the variance of the distance + estimate. + dtype: if not None, coerce activations to this dtype before computations. + + Returns: + The Kernel Inception Distance. A floating-point scalar of the same type + as the output of the activations. + """ + return kernel_classifier_distance_and_std_from_activations( + real_activations, generated_activations, max_block_size=max_block_size)[0] + + +def kernel_classifier_distance_and_std_from_activations(real_activations, + generated_activations, + max_block_size=1024, + dtype=None): + """Kernel "classifier" distance for evaluating a generative model. + + This methods computes the kernel classifier distance from activations of + real images and generated images. 
This can be used independently of the + kernel_classifier_distance() method, especially in the case of using large + batches during evaluation where we would like to precompute all of the + activations before computing the classifier distance, or if we want to + compute multiple metrics based on the same images. It also returns a rough + estimate of the standard error of the estimator. + + This technique is described in detail in https://arxiv.org/abs/1801.01401. + Given two distributions P and Q of activations, this function calculates + + E_{X, X' ~ P}[k(X, X')] + E_{Y, Y' ~ Q}[k(Y, Y')] + - 2 E_{X ~ P, Y ~ Q}[k(X, Y)] + + where k is the polynomial kernel + + k(x, y) = ( x^T y / dimension + 1 )^3. + + This captures how different the distributions of real and generated images' + visual features are. Like the Frechet distance (and unlike the Inception + score), this is a true distance and incorporates information about the + target images. Unlike the Frechet score, this function computes an + *unbiased* and asymptotically normal estimator, which makes comparing + estimates across models much more intuitive. + + The estimator used takes time quadratic in max_block_size. Larger values of + max_block_size will decrease the variance of the estimator but increase the + computational cost. This differs slightly from the estimator used by the + original paper; it is the block estimator of https://arxiv.org/abs/1307.1954. + The estimate of the standard error will also be more reliable when there are + more blocks, i.e. when max_block_size is smaller. + + NOTE: the blocking code assumes that real_activations and + generated_activations are both in random order. If either is sorted in a + meaningful order, the estimator will behave poorly. + + Args: + real_activations: 2D Tensor containing activations of real data. Shape is + [batch_size, activation_size]. + generated_activations: 2D Tensor containing activations of generated data. + Shape is [batch_size, activation_size]. 
+ max_block_size: integer, default 1024. The distance estimator splits samples + into blocks for computational efficiency. Larger values are more + computationally expensive but decrease the variance of the distance + estimate. Having a smaller block size also gives a better estimate of the + standard error. + dtype: if not None, coerce activations to this dtype before computations. + + Returns: + The Kernel Inception Distance. A floating-point scalar of the same type + as the output of the activations. + An estimate of the standard error of the distance estimator (a scalar of + the same type). + """ + + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + real_activations.shape[1].assert_is_compatible_with( + generated_activations.shape[1]) + + if dtype is None: + dtype = real_activations.dtype + assert generated_activations.dtype == dtype + else: + real_activations = math_ops.cast(real_activations, dtype) + generated_activations = math_ops.cast(generated_activations, dtype) + + # Figure out how to split the activations into blocks of approximately + # equal size, with none larger than max_block_size. 
+ n_r = array_ops.shape(real_activations)[0] + n_g = array_ops.shape(generated_activations)[0] + + n_bigger = math_ops.maximum(n_r, n_g) + n_blocks = math_ops.to_int32(math_ops.ceil(n_bigger / max_block_size)) + + v_r = n_r // n_blocks + v_g = n_g // n_blocks + + n_plusone_r = n_r - v_r * n_blocks + n_plusone_g = n_g - v_g * n_blocks + + sizes_r = array_ops.concat([ + array_ops.fill([n_blocks - n_plusone_r], v_r), + array_ops.fill([n_plusone_r], v_r + 1), + ], 0) + sizes_g = array_ops.concat([ + array_ops.fill([n_blocks - n_plusone_g], v_g), + array_ops.fill([n_plusone_g], v_g + 1), + ], 0) + + zero = array_ops.zeros([1], dtype=dtypes.int32) + inds_r = array_ops.concat([zero, math_ops.cumsum(sizes_r)], 0) + inds_g = array_ops.concat([zero, math_ops.cumsum(sizes_g)], 0) + + dim = math_ops.cast(real_activations.shape[1], dtype) + + def compute_kid_block(i): + 'Compute the ith block of the KID estimate.' + r_s = inds_r[i] + r_e = inds_r[i + 1] + r = real_activations[r_s:r_e] + m = math_ops.cast(r_e - r_s, dtype) + + g_s = inds_g[i] + g_e = inds_g[i + 1] + g = generated_activations[g_s:g_e] + n = math_ops.cast(g_e - g_s, dtype) + + k_rr = (math_ops.matmul(r, r, transpose_b=True) / dim + 1)**3 + k_rg = (math_ops.matmul(r, g, transpose_b=True) / dim + 1)**3 + k_gg = (math_ops.matmul(g, g, transpose_b=True) / dim + 1)**3 + return (-2 * math_ops.reduce_mean(k_rg) + + (math_ops.reduce_sum(k_rr) - math_ops.trace(k_rr)) / (m * (m - 1)) + + (math_ops.reduce_sum(k_gg) - math_ops.trace(k_gg)) / (n * (n - 1))) + + ests = functional_ops.map_fn( + compute_kid_block, math_ops.range(n_blocks), dtype=dtype, back_prop=False) + + mn = math_ops.reduce_mean(ests) + + # nn_impl.moments doesn't use the Bessel correction, which we want here + n_blocks_ = math_ops.cast(n_blocks, dtype) + var = control_flow_ops.cond( + math_ops.less_equal(n_blocks, 1), + lambda: array_ops.constant(float('nan'), dtype=dtype), + lambda: math_ops.reduce_sum(math_ops.square(ests - mn)) / (n_blocks_ - 1)) + + 
return mn, math_ops.sqrt(var / n_blocks_) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index d64dfd1576578435d0e3bd4e338fe2e9e4a6f6ab..dbff1d2a367e10adc607dafb4c571bb3607a3963 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -86,6 +86,42 @@ def _expected_fid(real_imgs, gen_imgs): def _expected_trace_sqrt_product(sigma, sigma_v): return np.trace(scp_linalg.sqrtm(np.dot(sigma, sigma_v))) + +def _expected_kid_and_std(real_imgs, gen_imgs, max_block_size=1024): + n_r, dim = real_imgs.shape + n_g = gen_imgs.shape[0] + + n_blocks = int(np.ceil(max(n_r, n_g) / max_block_size)) + + sizes_r = np.full(n_blocks, n_r // n_blocks) + to_patch = n_r - n_blocks * (n_r // n_blocks) + if to_patch > 0: + sizes_r[-to_patch:] += 1 + inds_r = np.r_[0, np.cumsum(sizes_r)] + assert inds_r[-1] == n_r + + sizes_g = np.full(n_blocks, n_g // n_blocks) + to_patch = n_g - n_blocks * (n_g // n_blocks) + if to_patch > 0: + sizes_g[-to_patch:] += 1 + inds_g = np.r_[0, np.cumsum(sizes_g)] + assert inds_g[-1] == n_g + + ests = [] + for i in range(n_blocks): + r = real_imgs[inds_r[i]:inds_r[i + 1]] + g = gen_imgs[inds_g[i]:inds_g[i + 1]] + + k_rr = (np.dot(r, r.T) / dim + 1)**3 + k_rg = (np.dot(r, g.T) / dim + 1)**3 + k_gg = (np.dot(g, g.T) / dim + 1)**3 + ests.append(-2 * k_rg.mean() + + k_rr[np.triu_indices_from(k_rr, k=1)].mean() + + k_gg[np.triu_indices_from(k_gg, k=1)].mean()) + + var = np.var(ests, ddof=1) if len(ests) > 1 else np.nan + return np.mean(ests), np.sqrt(var / len(ests)) + # A dummy GraphDef string with the minimum number of Ops. graphdef_string = """ node { @@ -272,6 +308,18 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): # Check that none of the model variables are trainable. 
self.assertListEqual([], variables.trainable_variables()) + def test_kernel_inception_distance_graph(self): + """Test `frechet_inception_distance` graph construction.""" + img = array_ops.ones([7, 299, 299, 3]) + distance = _run_with_mock(classifier_metrics.kernel_inception_distance, img, + img) + + self.assertTrue(isinstance(distance, ops.Tensor)) + distance.shape.assert_has_rank(0) + + # Check that none of the model variables are trainable. + self.assertListEqual([], variables.trainable_variables()) + def test_run_inception_multicall(self): """Test that `run_inception` can be called multiple times.""" for batch_size in (7, 3, 2): @@ -411,6 +459,56 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): # Check that the FIDs increase monotonically. self.assertTrue(all(fid_a < fid_b for fid_a, fid_b in zip(fids, fids[1:]))) + def test_kernel_classifier_distance_value(self): + """Test that `kernel_classifier_distance` gives the correct value.""" + np.random.seed(0) + + test_pool_real_a = np.float32(np.random.randn(512, 256)) + test_pool_gen_a = np.float32(np.random.randn(512, 256) * 1.1 + .05) + + kid_op = _run_with_mock( + classifier_metrics.kernel_classifier_distance_and_std, + test_pool_real_a, + test_pool_gen_a, + classifier_fn=lambda x: x, + max_block_size=600) + + with self.test_session() as sess: + actual_kid, actual_std = sess.run(kid_op) + + expected_kid, expected_std = _expected_kid_and_std(test_pool_real_a, + test_pool_gen_a) + + self.assertAllClose(expected_kid, actual_kid, 0.001) + self.assertAllClose(expected_std, actual_std, 0.001) + + def test_kernel_classifier_distance_block_sizes(self): + """Test that `kernel_classifier_distance` works with unusual max_block_size + + values.. 
+ """ + np.random.seed(0) + + test_pool_real_a = np.float32(np.random.randn(512, 256)) + test_pool_gen_a = np.float32(np.random.randn(768, 256) * 1.1 + .05) + + max_block_size = array_ops.placeholder(dtypes.int32, shape=()) + kid_op = _run_with_mock( + classifier_metrics.kernel_classifier_distance_and_std_from_activations, + array_ops.constant(test_pool_real_a), + array_ops.constant(test_pool_gen_a), + max_block_size=max_block_size) + + for block_size in [50, 512, 1000]: + with self.test_session() as sess: + actual_kid, actual_std = sess.run(kid_op, {max_block_size: block_size}) + + expected_kid, expected_std = _expected_kid_and_std( + test_pool_real_a, test_pool_gen_a, max_block_size=block_size) + + self.assertAllClose(expected_kid, actual_kid, 0.001) + self.assertAllClose(expected_std, actual_std, 0.001) + def test_trace_sqrt_product_value(self): """Test that `trace_sqrt_product` gives the correct value.""" np.random.seed(0) diff --git a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py index 2b7bb5f14e7f3d1b3f913d3426efaaae19079ffb..e4fac1976d605f1942947a747043d5c8b00392c1 100644 --- a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py +++ b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py @@ -47,13 +47,13 @@ class ClipWeightsTest(test.TestCase): train_op1 = opt.minimize(loss, var_list=self.variables) train_op2 = opt_clip.minimize(loss, var_list=self.variables) - with self.test_session(use_gpu=True) as sess: + with self.cached_session(use_gpu=True) as sess: sess.run(variables.global_variables_initializer()) self.assertEqual(2.0, self.variables[0].eval()) sess.run(train_op1) self.assertLess(0.1, self.variables[0].eval()) - with self.test_session(use_gpu=True) as sess: + with self.cached_session(use_gpu=True) as sess: sess.run(variables.global_variables_initializer()) self.assertEqual(2.0, self.variables[0].eval()) sess.run(train_op2) diff --git 
a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py index 08584dcd656e3e7a079a3fa36f44742b5eac1178..3c9dfd6de024b1558bed2e3678606fef8bb4d677 100644 --- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py @@ -37,7 +37,7 @@ class TensorPoolTest(test.TestCase): output_value = tensor_pool(input_value, pool_size=10) self.assertEqual(output_value.shape.as_list(), [None, None, 3]) - with self.test_session(use_gpu=True) as session: + with self.session(use_gpu=True) as session: for i in range(10): session.run(output_value, {input_value: [[[i] * 3]]}) session.run(output_value, {input_value: [[[i] * 3] * 2]}) @@ -49,7 +49,7 @@ class TensorPoolTest(test.TestCase): output_value = tensor_pool(input_value, pool_size=10) self.assertEqual(output_value.shape.as_list(), []) - with self.test_session(use_gpu=True) as session: + with self.session(use_gpu=True) as session: outs = [] for i in range(50): out = session.run(output_value, {input_value: i}) @@ -67,7 +67,7 @@ class TensorPoolTest(test.TestCase): input_value, pool_size=10, pooling_probability=0.0) self.assertEqual(output_value.shape.as_list(), []) - with self.test_session(use_gpu=True) as session: + with self.session(use_gpu=True) as session: for i in range(50): out = session.run(output_value, {input_value: i}) self.assertEqual(out, i) @@ -83,7 +83,7 @@ class TensorPoolTest(test.TestCase): pooling_probability=pooling_probability) self.assertEqual(output_value.shape.as_list(), []) - with self.test_session(use_gpu=True) as session: + with self.session(use_gpu=True) as session: not_pooled = 0 total = 1000 for i in range(total): @@ -104,7 +104,7 @@ class TensorPoolTest(test.TestCase): for output_value in output_values: self.assertEqual(output_value.shape.as_list(), []) - with self.test_session(use_gpu=True) as session: + with 
self.session(use_gpu=True) as session: for i in range(10): outs = session.run(output_values, { input_values[0]: i, diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py index 2fe06a287284ff994326d5a977a2e4d4634268ae..ecfbb8a432e3308863edd6f1343be55c1fe5753c 100644 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py @@ -59,7 +59,7 @@ class VirtualBatchnormTest(test.TestCase): mom_mean, mom_var = nn.moments(tensors, axes) vb_var = mean_sq - math_ops.square(vb_mean) - with self.test_session(use_gpu=True) as sess: + with self.cached_session(use_gpu=True) as sess: vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([ vb_mean, vb_var, mom_mean, mom_var]) @@ -93,7 +93,7 @@ class VirtualBatchnormTest(test.TestCase): vb_mean = array_ops.squeeze(vb_mean, batch_axis) vb_variance = array_ops.squeeze(vb_variance, batch_axis) - with self.test_session(use_gpu=True) as sess: + with self.cached_session(use_gpu=True) as sess: vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([ vb_mean, vb_variance, mom_mean, mom_variance]) @@ -116,7 +116,7 @@ class VirtualBatchnormTest(test.TestCase): vbn = virtual_batchnorm.VBN(batch, axis, batch_axis=batch_axis) vbn_normalized = vbn.reference_batch_normalization() - with self.test_session(use_gpu=True) as sess: + with self.cached_session(use_gpu=True) as sess: variables_lib.global_variables_initializer().run() bn_normalized_np, vbn_normalized_np = sess.run( @@ -142,7 +142,7 @@ class VirtualBatchnormTest(test.TestCase): vb_normed = array_ops.squeeze( vbn(array_ops.expand_dims(examples[i], [0])), [0]) - with self.test_session(use_gpu=True) as sess: + with self.cached_session(use_gpu=True) as sess: variables_lib.global_variables_initializer().run() bn_np, vb_np = sess.run([batch_normalized, vb_normed]) self.assertAllClose(bn_np[i, 
...], vb_np) @@ -167,7 +167,7 @@ class VirtualBatchnormTest(test.TestCase): vbn = virtual_batchnorm.VBN(reference_batch) vbn_fixed_example = array_ops.squeeze( vbn(array_ops.expand_dims(fixed_example, 0)), 0) - with self.test_session(use_gpu=True): + with self.session(use_gpu=True): variables_lib.global_variables_initializer().run() vbn_fixed_example_np = vbn_fixed_example.eval() @@ -180,7 +180,7 @@ class VirtualBatchnormTest(test.TestCase): minibatch = array_ops.stack([fixed_example] + examples) vbn_minibatch = vbn(minibatch) cur_vbn_fixed_example = vbn_minibatch[0, ...] - with self.test_session(use_gpu=True): + with self.cached_session(use_gpu=True): variables_lib.global_variables_initializer().run() cur_vbn_fixed_example_np = cur_vbn_fixed_example.eval() self.assertAllClose(vbn_fixed_example_np, cur_vbn_fixed_example_np) @@ -219,7 +219,7 @@ class VirtualBatchnormTest(test.TestCase): self.assertEqual(4, len(contrib_variables_lib.get_variables())) - with self.test_session(use_gpu=True) as sess: + with self.session(use_gpu=True) as sess: variables_lib.global_variables_initializer().run() sess.run(to_fetch) diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index bb06f1c41c1d60f3c3b3639e3b32ea85161510b2..3549cedb70a6104ff3d3829d1b94cb5f08c5119c 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -30,19 +29,17 @@ limitations under the License. 
#include #include "tensorflow/contrib/gdr/gdr.pb.h" -#include "tensorflow/core/common_runtime/bfc_allocator.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/common_runtime/pool_allocator.h" #include "tensorflow/core/common_runtime/process_state.h" #if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #endif // GOOGLE_CUDA -#include "tensorflow/core/framework/allocator_registry.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/numa.h" namespace tensorflow { @@ -70,14 +67,11 @@ bool IsGDRAvailable() { int TryToReadNumaNode(ibv_device* device) { #if defined(__APPLE__) LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0"; - return 0; + return port::kNUMANoAffinity; #elif defined(PLATFORM_WINDOWS) // Windows support for NUMA is not currently implemented. Return node 0. 
- return 0; + return port::kNUMANoAffinity; #else - VLOG(2) << "Trying to read NUMA node for device: " << device->name; - static const int kUnknownNumaNode = -1; - auto filename = string(device->ibdev_path) + "/device/numa_node"; std::ifstream ifs(filename.c_str()); @@ -91,12 +85,12 @@ int TryToReadNumaNode(ibv_device* device) { << value << "), but there must be at least one NUMA node" ", so returning NUMA node zero"; - return 0; + return port::kNUMANoAffinity; } LOG(INFO) << "NUMA node for device: " << device->name << " is " << value; return value; } - return kUnknownNumaNode; + return port::kNUMANoAffinity; #endif } @@ -138,8 +132,6 @@ class GdrMemoryManager : public RemoteMemoryManager { Device* device, DeviceContext* device_context, bool on_host, StatusCallback done) override; - static void RegMemVisitors(); - protected: Status CreateEndpoint(const string& host, const string& port, RdmaEndpointPtr& endpoint); @@ -150,7 +142,8 @@ class GdrMemoryManager : public RemoteMemoryManager { ibv_mr* FindMemoryRegion(void* addr, size_t length); - void InsertMemoryRegion(void* addr, size_t length); + void InsertMemoryRegion(void* addr, size_t length, + const std::string& allocator_name); void EvictMemoryRegion(void* addr, size_t length); @@ -160,6 +153,7 @@ class GdrMemoryManager : public RemoteMemoryManager { RdmaEndpointPtr listening_; std::atomic stopped_; int epfd_; + int numa_node_; // Server side endpoints // Accessed sequentially in Run() so not protected by lock @@ -190,46 +184,10 @@ GdrMemoryManager::GdrMemoryManager(const string& host, const string& port) port_(port), listening_(nullptr, EndpointDeleter), stopped_(true), - next_key_(0) { - static std::once_flag flag; - std::call_once(flag, []() { RegMemVisitors(); }); -} + next_key_(0) {} GdrMemoryManager::~GdrMemoryManager() { close(epfd_); } -/*static*/ void GdrMemoryManager::RegMemVisitors() { - SubAllocator::Visitor alloc_visitor = [](void* ptr, int numa_node, - size_t num_bytes) { - 
GdrMemoryManager::Singleton().InsertMemoryRegion( - ptr, num_bytes, strings::StrCat("CPU:", numa_node)); - }; - SubAllocator::Visitor free_visitor = [](void* ptr, int numa_node, - size_t num_bytes) { - GdrMemoryManager::Singleton().EvictMemoryRegion(ptr, num_bytes); - }; - ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor); - ProcessState::singleton()->AddCPUFreeVisitor(free_visitor); - -#if GOOGLE_CUDA - if (IsGDRAvailable()) { - int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1; - - // Note we don't free allocated GPU memory so there is no free visitor - SubAllocator::Visitor cuda_alloc_visitor = [](void* ptr, int gpu_id, - size_t num_bytes) { - RdmaMemoryMgr::Singleton().InsertMemoryRegion( - ptr, num_bytes, strings::StrCat("GPU:", gpu_id)); - }; - GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id, - cuda_alloc_visitor); - GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id, - alloc_visitor); - GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor); - LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id; - } -#endif // GOOGLE_CUDA -} - Status GdrMemoryManager::Init() { epfd_ = epoll_create1(0); if (epfd_ == -1) { @@ -289,6 +247,42 @@ Status GdrMemoryManager::Init() { "cannot add server to epoll"); } + numa_node_ = TryToReadNumaNode(listening_->verbs->device); + + SubAllocator::Visitor alloc_visitor = [this](void* ptr, int numa_node, + size_t num_bytes) { + VLOG(2) << "Registering RDMA capable memory region on numa_node " + << numa_node; + InsertMemoryRegion(ptr, num_bytes, strings::StrCat("CPU:", numa_node)); + }; + SubAllocator::Visitor free_visitor = [this](void* ptr, int numa_node, + size_t num_bytes) { + VLOG(2) << "De-registering RDMA capable memory region on numa_node " + << numa_node; + EvictMemoryRegion(ptr, num_bytes); + }; + ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor); + ProcessState::singleton()->AddCPUFreeVisitor(free_visitor); + LOG(INFO) << 
"Instrumenting CPU allocator(s)"; + +#if GOOGLE_CUDA + if (IsGDRAvailable()) { + int bus_id = numa_node_ + 1; + + SubAllocator::Visitor cuda_alloc_visitor = [this](void* ptr, int gpu_id, + size_t num_bytes) { + VLOG(2) << "Registering RDMA capable memory region on GPU " << gpu_id; + InsertMemoryRegion(ptr, num_bytes, strings::StrCat("GPU:", gpu_id)); + }; + GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id, + cuda_alloc_visitor); + GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id, + alloc_visitor); + GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor); + LOG(INFO) << "Instrumenting GPU allocator(s) with bus_id " << bus_id; + } +#endif // GOOGLE_CUDA + return Status::OK(); } @@ -405,7 +399,7 @@ void GdrMemoryManager::TransportOptionsFromTensor( ibv_mr* mr = FindMemoryRegion(addr, length); #if GOOGLE_CUDA - if (!on_host) { + if (device->tensorflow_gpu_device_info() && !on_host) { Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); GPUUtil::CopyGPUTensorToCPU( @@ -456,11 +450,27 @@ void GdrMemoryManager::TransportOptionsFromTensor( #endif if (mr == nullptr) { - done(errors::Unavailable("Cannot find pinned memory region")); - return; + Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_); + Tensor host_copy(alloc, tensor.dtype(), tensor.shape()); + + std::memcpy(DMAHelper::buffer(&host_copy)->data(), buffer->data(), length); + VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer"; + + buffer = DMAHelper::buffer(&host_copy); + addr = buffer->data(); + length = buffer->size(); + + mr = FindMemoryRegion(addr, length); + if (mr == nullptr) { + done(errors::Unavailable("Cannot find pinned memory region")); + return; + } + + buffer->Ref(); + } else { + buffer->Ref(); } - buffer->Ref(); TensorKey tensor_key = next_key_++; { mutex_lock l(server_mu_); @@ -470,7 +480,7 @@ void 
GdrMemoryManager::TransportOptionsFromTensor( uint64_t checksum = 0; if (VLOG_IS_ON(2)) { #ifdef GOOGLE_CUDA - if (!on_host) { + if (device->tensorflow_gpu_device_info() && !on_host) { checksum = GPUUtil::Checksum(device, device_context, tensor); } else { checksum = GPUUtil::Checksum(tensor); @@ -508,7 +518,8 @@ void GdrMemoryManager::TensorFromTransportOptions( Tensor host_copy; #if GOOGLE_CUDA if (mr == nullptr && !on_host) { - Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); + Allocator* alloc = + GPUProcessState::singleton()->GetCUDAHostAllocator(numa_node_); host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); buffer = DMAHelper::buffer(&host_copy); addr = buffer->data(); @@ -518,8 +529,18 @@ void GdrMemoryManager::TensorFromTransportOptions( #endif // GOOGLE_CUDA if (mr == nullptr) { - done(errors::Unavailable("Cannot find pinned memory region")); - return; + Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_); + host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); + + buffer = DMAHelper::buffer(&host_copy); + addr = buffer->data(); + length = buffer->size(); + + mr = FindMemoryRegion(addr, length); + if (mr == nullptr) { + done(errors::Unavailable("Cannot find pinned memory region")); + return; + } } decltype(clients_)::iterator iter; @@ -568,7 +589,8 @@ void GdrMemoryManager::TensorFromTransportOptions( } #if GOOGLE_CUDA - if (host_copy.NumElements() > 0) { + if (device->tensorflow_gpu_device_info() && !on_host && + host_copy.NumElements() > 0) { uint64_t checksum = 0; if (VLOG_IS_ON(2)) { checksum = GPUUtil::Checksum(host_copy); @@ -598,6 +620,12 @@ void GdrMemoryManager::TensorFromTransportOptions( } #endif // GOOGLE_CUDA + if ((on_host || !device->tensorflow_gpu_device_info()) && + host_copy.NumElements() > 0) { + std::memcpy(DMAHelper::buffer(tensor)->data(), addr, length); + VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer"; + } + uint64_t end = Env::Default()->NowMicros(); 
VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey() @@ -607,7 +635,7 @@ void GdrMemoryManager::TensorFromTransportOptions( uint64_t checksum = 0; if (VLOG_IS_ON(2)) { #ifdef GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && (!on_host)) { + if (device->tensorflow_gpu_device_info() && !on_host) { checksum = GPUUtil::Checksum(device, device_context, *tensor); } else { checksum = GPUUtil::Checksum(*tensor); @@ -668,7 +696,8 @@ ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) { } } -void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) { +void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length, + const std::string& allocator_name) { if (length == 0) return; ibv_mr* mr = rdma_reg_read(listening_.get(), addr, length); if (mr != nullptr) { @@ -676,7 +705,8 @@ void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) { auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); mrs_.insert(iter, {mr, &MRDeleter}); } else { - LOG(WARNING) << "Cannot register memory region"; + LOG(WARNING) << "Cannot register memory region allocated by " + << allocator_name; } } diff --git a/tensorflow/contrib/hadoop/BUILD b/tensorflow/contrib/hadoop/BUILD index ccad31efa1dba92d954ff1cb455b6c9c784b29bc..178a8a6f08410bd9e5b61db47a3866ec6060a48c 100644 --- a/tensorflow/contrib/hadoop/BUILD +++ b/tensorflow/contrib/hadoop/BUILD @@ -7,12 +7,12 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", "tf_custom_op_library", - "tf_custom_op_py_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", "tf_kernel_library", "tf_py_test", ) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") filegroup( name = "test_data", diff --git a/tensorflow/contrib/ignite/BUILD b/tensorflow/contrib/ignite/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..9393b702d11a2ef84586f712d30c26fe2a8972bb --- /dev/null +++ b/tensorflow/contrib/ignite/BUILD @@ -0,0 +1,139 @@ 
+package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "if_not_windows", + "if_windows", + "tf_custom_op_library", + "tf_custom_op_py_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_py_test", +) + +py_library( + name = "ignite", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_ops", + ], +) + +tf_custom_op_library( + name = "_dataset_ops.so", + srcs = ["ops/dataset_ops.cc"], + deps = [":dataset_kernels"], +) + +tf_gen_op_libs( + op_lib_names = ["dataset_ops"], +) + +cc_library( + name = "dataset_kernels", + srcs = [ + "kernels/ignite_dataset_ops.cc", + "kernels/ignite_client.h", + "kernels/ignite_byte_swapper.h", + "kernels/ignite_plain_client.h", + "kernels/ignite_ssl_wrapper.h", + "kernels/ignite_ssl_wrapper.cc", + "kernels/ignite_binary_object_parser.h", + "kernels/ignite_binary_object_parser.cc", + "kernels/ignite_dataset.h", + "kernels/ignite_dataset.cc", + "kernels/ignite_dataset_iterator.h", + "kernels/ignite_dataset_iterator.cc", + ] + if_not_windows([ + "kernels/ignite_plain_client_unix.cc", + ]) + if_windows([ + "kernels/ignite_plain_client_windows.cc", + ]), + copts = if_windows([ + "-DWIN32_LEAN_AND_MEAN", + ]), + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@boringssl//:ssl", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +py_library( + name = "dataset_ops", + srcs = [ + "python/ops/ignite_dataset_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":ignite_op_loader", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_dataset_ops", + out = "python/ops/gen_dataset_ops.py", + deps = ["//tensorflow/contrib/ignite:dataset_ops_op_lib"], +) + +tf_kernel_library( + name 
= "dataset_ops_kernels", + deps = [ + ":dataset_kernels", + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "ignite_op_loader", + srcs = ["python/ops/ignite_op_loader.py"], + dso = ["//tensorflow/contrib/ignite:_dataset_ops.so"], + kernels = [ + ":dataset_ops_kernels", + "//tensorflow/contrib/ignite:dataset_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_dataset_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + ], +) + +# The Apache Ignite servers have to setup before the test and tear down +# after the test manually. The docker engine has to be installed. +# +# To setup Apache Ignite servers: +# $ bash ./python/tests/start_ignite.sh +# +# To tear down Apache Ignite servers: +# $ bash ./python/tests/stop_ignite.sh +tf_py_test( + name = "ignite_dataset_test", + srcs = ["python/tests/ignite_dataset_test.py"], + additional_deps = [ + ":ignite", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + tags = [ + "manual", + "no_windows", + "notap", + ], +) diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md new file mode 100644 index 0000000000000000000000000000000000000000..55c89d27996318dabb29bb15372411005301ebd9 --- /dev/null +++ b/tensorflow/contrib/ignite/README.md @@ -0,0 +1,167 @@ +# Ignite Dataset + +- [Overview](#overview) +- [Features](#features) + * [Distributed In-Memory Datasource](#distributed-in-memory-datasource) + * [Structured Objects](#structured-objects) + * [Distributed Training](#distributed-training) + * [SSL Connection](#ssl-connection) + * [Windows Support](#windows-support) +- [Try it out](#try-it-out) +- [Limitations](#limitations) + +## Overview + +[Apache Ignite](https://ignite.apache.org/) is a memory-centric distributed database, caching, and processing platform for +transactional, analytical, and 
streaming workloads, delivering in-memory speeds at petabyte scale. This contrib package contains an integration between Apache Ignite and TensorFlow. The integration is based on [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) from TensorFlow side and [Binary Client Protocol](https://apacheignite.readme.io/v2.6/docs/binary-client-protocol) from Apache Ignite side. It allows you to use Apache Ignite as a data source for neural network training, inference and all other computations supported by TensorFlow. + +## Features + +Ignite Dataset provides features that you can use in a wide range of cases. The most important and interesting features are described below. + +### Distributed In-Memory Datasource +[Apache Ignite](https://ignite.apache.org/) is a distributed in-memory database, caching, and processing platform that provides fast data access. It allows you to avoid the limitations of a hard drive and to store and operate on as much data as you need in a distributed cluster. You can utilize +these benefits of Apache Ignite by using Ignite Dataset. Moreover, Ignite Dataset can be used for the following use-cases: +- If you have a **gigabyte** of data you can keep it on a single machine on a hard drive, but you will face hard drive speed limitations. At the same time, you can store your data in Apache Ignite on the same machine and use it as a datasource for TensorFlow and thus avoid these limitations. +- If you have a **terabyte** of data you probably still can keep it on a single machine on a hard drive, but you will face hard drive speed limitations again. At the same time, you can store your data in an Apache Ignite distributed in-memory cluster and use it as a datasource for TensorFlow and thus avoid these limitations. +- If you have a **petabyte** of data you can't keep it on a single machine. At the same time, you can store your data in an Apache Ignite distributed in-memory cluster and use it as a datasource for TensorFlow. 
+ +Note that Apache Ignite is not just a step of ETL pipeline between a database or a data warehouse and TensorFlow. Apache Ignite is a high-grade database itself. By choosing Apache Ignite and TensorFlow you are getting everything you need to work with operational or historical data and, at the same time, an ability to use this data for neural network training and inference. + +```bash +$ apache-ignite-fabric/bin/ignite.sh +$ apache-ignite-fabric/bin/sqlline.sh -u "jdbc:ignite:thin://localhost:10800/" + +jdbc:ignite:thin://localhost/> CREATE TABLE KITTEN_CACHE (ID LONG PRIMARY KEY, NAME VARCHAR); +jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (1, 'WARM KITTY'); +jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (2, 'SOFT KITTY'); +jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL OF FUR'); +``` + +```python +>>> import tensorflow as tf +>>> from tensorflow.contrib.ignite import IgniteDataset +>>> +>>> dataset = IgniteDataset(cache_name="SQL_PUBLIC_KITTEN_CACHE") +>>> iterator = dataset.make_one_shot_iterator() +>>> next_obj = iterator.get_next() +>>> +>>> with tf.Session() as sess: +>>> for _ in range(3): +>>> print(sess.run(next_obj)) + +{'key': 1, 'val': {'NAME': b'WARM KITTY'}} +{'key': 2, 'val': {'NAME': b'SOFT KITTY'}} +{'key': 3, 'val': {'NAME': b'LITTLE BALL OF FUR'}} +``` + +### Structured Objects +[Apache Ignite](https://ignite.apache.org/) allows to store any type of objects. These objects can have any hierarchy. Ignite Dataset provides an ability to work with such objects. 
+ +```python +>>> import tensorflow as tf +>>> from tensorflow.contrib.ignite import IgniteDataset +>>> +>>> dataset = IgniteDataset(cache_name="IMAGES") +>>> iterator = dataset.make_one_shot_iterator() +>>> next_obj = iterator.get_next() +>>> +>>> with tf.Session() as sess: +>>> print(sess.run(next_obj)) + +{ + 'key': 'kitten.png', + 'val': { + 'metadata': { + 'file_name': b'kitten.png', + 'label': b'little ball of fur', + 'width': 800, + 'height': 600 + }, + 'pixels': [0, 0, 0, 0, ..., 0] + } +} +``` + Neural network training and other computations require transformations that can be done as part of a [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) pipeline if you use Ignite Dataset. + +```python +>>> import tensorflow as tf +>>> from tensorflow.contrib.ignite import IgniteDataset +>>> +>>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels']) +>>> iterator = dataset.make_one_shot_iterator() +>>> next_obj = iterator.get_next() +>>> +>>> with tf.Session() as sess: +>>> print(sess.run(next_obj)) + +[0, 0, 0, 0, ..., 0] +``` + +### Distributed Training + +TensorFlow is a machine learning framework that [natively supports](https://www.tensorflow.org/deploy/distributed) distributed neural network training, inference and other computations. The main idea behind distributed neural network training is the ability to calculate gradients of loss functions (squares of the errors) on every partition of data (in terms of horizontal partitioning) and then sum them to get the loss function gradient of the whole dataset. + + + +Using this ability we can calculate gradients on the nodes the data is stored on, reduce them and then finally update model parameters. This allows us to avoid data transfers between nodes and thus to avoid network bottlenecks. + +Apache Ignite uses horizontal partitioning to store data in a distributed cluster. 
When we create an Apache Ignite cache (or table in terms of SQL), we can specify the number of partitions the data will be partitioned on. For example, if an Apache Ignite cluster consists of 10 machines and we create a cache with 10 partitions, then every machine will maintain approximately one data partition. + +Ignite Dataset allows using these two aspects of distributed neural network training (using TensorFlow) and Apache Ignite partitioning. Ignite Dataset is a computation graph operation that can be performed on a remote worker. The remote worker can override Ignite Dataset parameters (such as `host`, `port` or `part`) by setting the corresponding environment variables for the worker process (such as `IGNITE_DATASET_HOST`, `IGNITE_DATASET_PORT` or `IGNITE_DATASET_PART`). Using this overriding approach, we can assign a specific partition to every worker so that one worker handles one partition and, at the same time, transparently work with a single dataset. + +```python +>>> import tensorflow as tf +>>> from tensorflow.contrib.ignite import IgniteDataset +>>> +>>> dataset = IgniteDataset("IMAGES") +>>> +>>> # Compute gradients locally on every worker node. +>>> gradients = [] +>>> for i in range(5): +>>> with tf.device("/job:WORKER/task:%d" % i): +>>> device_iterator = dataset.make_one_shot_iterator() +>>> device_next_obj = device_iterator.get_next() +>>> gradient = compute_gradient(device_next_obj) +>>> gradients.append(gradient) +>>> +>>> # Aggregate them on master node. +>>> result_gradient = tf.reduce_sum(gradients) +>>> +>>> with tf.Session("grpc://localhost:10000") as sess: +>>> print(sess.run(result_gradient)) +``` + +High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well. + +### SSL Connection + +Apache Ignite allows protecting data transfer channels with [SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and authentication. 
Ignite Dataset supports SSL connections both with and without authentication. For more information, please refer to the [Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) documentation. + +```python +>>> import tensorflow as tf +>>> from tensorflow.contrib.ignite import IgniteDataset +>>> +>>> dataset = IgniteDataset(cache_name="IMAGES", certfile="client.pem", cert_password="password", username="ignite", password="ignite") +>>> ... +``` + +### Windows Support + +Ignite Dataset is fully compatible with Windows. You can use it as part of TensorFlow on your Windows workstation as well as on Linux/MacOS systems. + +## Try it out + +The simplest way to try Ignite Dataset is to run a [Docker](https://www.docker.com/) container with Apache Ignite and loaded [MNIST](http://yann.lecun.com/exdb/mnist/) data and, once it has started, interact with it using Ignite Dataset. Such a container is available on Docker Hub: [dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). You need to start this container on your machine: + +``` +docker run -it -p 10800:10800 dmitrievanthony/ignite-with-mnist +``` + +After that you will be able to work with it in the following way: + +![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist.png "Ignite Dataset Mnist") + +## Limitations + +Presently, Ignite Dataset works with the assumption that all objects in the cache have the same structure (homogeneous objects) and the cache contains at least one object. Another limitation concerns structured objects: Ignite Dataset does not support UUID, Maps and Object arrays that might be parts of an object structure. diff --git a/tensorflow/contrib/ignite/__init__.py b/tensorflow/contrib/ignite/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f42947696f76e168f77b2316758209f1f71a7915 --- /dev/null +++ b/tensorflow/contrib/ignite/__init__.py @@ -0,0 +1,42 @@ +# Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""IgniteDataset that allows to get data from Apache Ignite. + +Apache Ignite is a memory-centric distributed database, caching, and +processing platform for transactional, analytical, and streaming workloads, +delivering in-memory speeds at petabyte scale. This contrib package +contains an integration between Apache Ignite and TensorFlow. The +integration is based on tf.data from TensorFlow side and Binary Client +Protocol from Apache Ignite side. It allows to use Apache Ignite as a +datasource for neural network training, inference and all other +computations supported by TensorFlow. Ignite Dataset is based on Apache +Ignite Binary Client Protocol: +https://apacheignite.readme.io/v2.6/docs/binary-client-protocol. 
+ +@@IgniteDataset +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.ignite.python.ops.ignite_dataset_ops import IgniteDataset +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "IgniteDataset", +] + +remove_undocumented(__name__) diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c8a7d44b07b43f788bcbc0850b5162cc14dd951 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc @@ -0,0 +1,334 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +BinaryObjectParser::BinaryObjectParser() : byte_swapper_(ByteSwapper(false)) {} + +Status BinaryObjectParser::Parse(uint8_t** ptr, + std::vector* out_tensors, + std::vector* types) const { + uint8_t object_type_id = ParseByte(ptr); + + // Skip non-leaf nodes. 
+ if (object_type_id != WRAPPED_OBJ && object_type_id != COMPLEX_OBJ) + types->push_back(object_type_id); + + switch (object_type_id) { + case BYTE: { + out_tensors->emplace_back(cpu_allocator(), DT_UINT8, TensorShape({})); + out_tensors->back().scalar()() = ParseByte(ptr); + break; + } + case SHORT: { + out_tensors->emplace_back(cpu_allocator(), DT_INT16, TensorShape({})); + out_tensors->back().scalar()() = ParseShort(ptr); + break; + } + case USHORT: { + out_tensors->emplace_back(cpu_allocator(), DT_UINT16, TensorShape({})); + out_tensors->back().scalar()() = ParseUnsignedShort(ptr); + break; + } + case INT: { + out_tensors->emplace_back(cpu_allocator(), DT_INT32, TensorShape({})); + out_tensors->back().scalar()() = ParseInt(ptr); + break; + } + case LONG: { + out_tensors->emplace_back(cpu_allocator(), DT_INT64, TensorShape({})); + out_tensors->back().scalar()() = ParseLong(ptr); + break; + } + case FLOAT: { + out_tensors->emplace_back(cpu_allocator(), DT_FLOAT, TensorShape({})); + out_tensors->back().scalar()() = ParseFloat(ptr); + break; + } + case DOUBLE: { + out_tensors->emplace_back(cpu_allocator(), DT_DOUBLE, TensorShape({})); + out_tensors->back().scalar()() = ParseDouble(ptr); + break; + } + case BOOL: { + out_tensors->emplace_back(cpu_allocator(), DT_BOOL, TensorShape({})); + out_tensors->back().scalar()() = ParseBool(ptr); + break; + } + case STRING: { + out_tensors->emplace_back(cpu_allocator(), DT_STRING, TensorShape({})); + out_tensors->back().scalar()() = ParseString(ptr); + break; + } + case DATE: { + out_tensors->emplace_back(cpu_allocator(), DT_INT64, TensorShape({})); + out_tensors->back().scalar()() = ParseLong(ptr); + break; + } + case BYTE_ARR: { + int32_t length = ParseInt(ptr); + uint8_t* arr = ParseByteArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_UINT8, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case SHORT_ARR: { + int32_t length = ParseInt(ptr); + int16_t* 
arr = ParseShortArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_INT16, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case USHORT_ARR: { + int32_t length = ParseInt(ptr); + uint16_t* arr = ParseUnsignedShortArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_UINT16, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case INT_ARR: { + int32_t length = ParseInt(ptr); + int32_t* arr = ParseIntArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_INT32, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case LONG_ARR: { + int32_t length = ParseInt(ptr); + int64_t* arr = ParseLongArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_INT64, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case FLOAT_ARR: { + int32_t length = ParseInt(ptr); + float* arr = ParseFloatArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_FLOAT, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case DOUBLE_ARR: { + int32_t length = ParseInt(ptr); + double* arr = ParseDoubleArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_DOUBLE, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case BOOL_ARR: { + int32_t length = ParseInt(ptr); + bool* arr = ParseBoolArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_BOOL, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case STRING_ARR: { + int32_t length = ParseInt(ptr); + out_tensors->emplace_back(cpu_allocator(), DT_STRING, + TensorShape({length})); + for (int32_t i = 0; i < length; i++) + out_tensors->back().vec()(i) = ParseString(ptr); + break; + } + case DATE_ARR: { + int32_t 
length = ParseInt(ptr); + int64_t* arr = ParseLongArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_INT64, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case WRAPPED_OBJ: { + int32_t byte_arr_size = ParseInt(ptr); + TF_RETURN_IF_ERROR(Parse(ptr, out_tensors, types)); + int32_t offset = ParseInt(ptr); + + break; + } + case COMPLEX_OBJ: { + uint8_t version = ParseByte(ptr); + int16_t flags = ParseShort(ptr); + int32_t type_id = ParseInt(ptr); + int32_t hash_code = ParseInt(ptr); + int32_t length = ParseInt(ptr); + int32_t schema_id = ParseInt(ptr); + int32_t schema_offset = ParseInt(ptr); + + // 24 is size of header just read. + uint8_t* end = *ptr + schema_offset - 24; + int32_t i = 0; + while (*ptr < end) { + i++; + TF_RETURN_IF_ERROR(Parse(ptr, out_tensors, types)); + } + + *ptr += (length - schema_offset); + + break; + } + default: { + return errors::Unknown("Unknowd binary type (type id ", + (int)object_type_id, ")"); + } + } + + return Status::OK(); +} + +uint8_t BinaryObjectParser::ParseByte(uint8_t** ptr) const { + uint8_t res = **ptr; + *ptr += 1; + + return res; +} + +int16_t BinaryObjectParser::ParseShort(uint8_t** ptr) const { + int16_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredInt16(res); + *ptr += 2; + + return *res; +} + +uint16_t BinaryObjectParser::ParseUnsignedShort(uint8_t** ptr) const { + uint16_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredUnsignedInt16(res); + *ptr += 2; + + return *res; +} + +int32_t BinaryObjectParser::ParseInt(uint8_t** ptr) const { + int32_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredInt32(res); + *ptr += 4; + + return *res; +} + +int64_t BinaryObjectParser::ParseLong(uint8_t** ptr) const { + int64_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredInt64(res); + *ptr += 8; + + return *res; +} + +float BinaryObjectParser::ParseFloat(uint8_t** ptr) const { + float* res = 
*reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredFloat(res); + *ptr += 4; + + return *res; +} + +double BinaryObjectParser::ParseDouble(uint8_t** ptr) const { + double* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredDouble(res); + *ptr += 8; + + return *res; +} + +bool BinaryObjectParser::ParseBool(uint8_t** ptr) const { + bool res = **reinterpret_cast(ptr); + *ptr += 1; + + return res; +} + +string BinaryObjectParser::ParseString(uint8_t** ptr) const { + int32_t length = ParseInt(ptr); + string res(*reinterpret_cast(ptr), length); + *ptr += length; + + return res; +} + +uint8_t* BinaryObjectParser::ParseByteArr(uint8_t** ptr, int length) const { + uint8_t* res = *reinterpret_cast(ptr); + *ptr += length; + + return res; +} + +int16_t* BinaryObjectParser::ParseShortArr(uint8_t** ptr, int length) const { + int16_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredInt16Arr(res, length); + *ptr += length * 2; + + return res; +} + +uint16_t* BinaryObjectParser::ParseUnsignedShortArr(uint8_t** ptr, + int length) const { + uint16_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredUnsignedInt16Arr(res, length); + *ptr += length * 2; + + return res; +} + +int32_t* BinaryObjectParser::ParseIntArr(uint8_t** ptr, int length) const { + int32_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredInt32Arr(res, length); + *ptr += length * 4; + + return res; +} + +int64_t* BinaryObjectParser::ParseLongArr(uint8_t** ptr, int length) const { + int64_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredInt64Arr(res, length); + *ptr += length * 8; + + return res; +} + +float* BinaryObjectParser::ParseFloatArr(uint8_t** ptr, int length) const { + float* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredFloatArr(res, length); + *ptr += length * 4; + + return res; +} + +double* BinaryObjectParser::ParseDoubleArr(uint8_t** ptr, int length) const { + double* res = *reinterpret_cast(ptr); + 
byte_swapper_.SwapIfRequiredDoubleArr(res, length); + *ptr += length * 8; + + return res; +} + +bool* BinaryObjectParser::ParseBoolArr(uint8_t** ptr, int length) const { + bool* res = *reinterpret_cast(ptr); + *ptr += length; + + return res; +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..eb1f856643a790de6acaa82d4b8ad894fd364376 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ + +#include +#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class BinaryObjectParser { + public: + BinaryObjectParser(); + Status Parse(uint8_t** ptr, std::vector* out_tensors, + std::vector* types) const; + + private: + uint8_t ParseByte(uint8_t** ptr) const; + int16_t ParseShort(uint8_t** ptr) const; + uint16_t ParseUnsignedShort(uint8_t** ptr) const; + int32_t ParseInt(uint8_t** ptr) const; + int64_t ParseLong(uint8_t** ptr) const; + float ParseFloat(uint8_t** ptr) const; + double ParseDouble(uint8_t** ptr) const; + bool ParseBool(uint8_t** ptr) const; + string ParseString(uint8_t** ptr) const; + uint8_t* ParseByteArr(uint8_t** ptr, int length) const; + int16_t* ParseShortArr(uint8_t** ptr, int length) const; + uint16_t* ParseUnsignedShortArr(uint8_t** ptr, int length) const; + int32_t* ParseIntArr(uint8_t** ptr, int length) const; + int64_t* ParseLongArr(uint8_t** ptr, int length) const; + float* ParseFloatArr(uint8_t** ptr, int length) const; + double* ParseDoubleArr(uint8_t** ptr, int length) const; + bool* ParseBoolArr(uint8_t** ptr, int length) const; + + const ByteSwapper byte_swapper_; +}; + +enum ObjectType { + BYTE = 1, + SHORT = 2, + INT = 3, + LONG = 4, + FLOAT = 5, + DOUBLE = 6, + USHORT = 7, + BOOL = 8, + STRING = 9, + DATE = 11, + BYTE_ARR = 12, + SHORT_ARR = 13, + INT_ARR = 14, + LONG_ARR = 15, + FLOAT_ARR = 16, + DOUBLE_ARR = 17, + USHORT_ARR = 18, + BOOL_ARR = 19, + STRING_ARR = 20, + DATE_ARR = 22, + WRAPPED_OBJ = 27, + COMPLEX_OBJ = 103 +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ diff --git 
a/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h b/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h new file mode 100644 index 0000000000000000000000000000000000000000..46df3e39dc4ec6dd4ef5730a184264eaa9fc5872 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h @@ -0,0 +1,126 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ + +#include +#include "tensorflow/core/platform/byte_order.h" + +namespace tensorflow { + +class ByteSwapper { + public: + ByteSwapper(bool big_endian) { swap_ = big_endian == port::kLittleEndian; } + + inline void SwapIfRequiredInt16(int16_t *x) const { + if (swap_) { + Swap16(x); + } + } + + inline void SwapIfRequiredUnsignedInt16(uint16_t *x) const { + if (swap_) { + Swap16(reinterpret_cast(x)); + } + } + + inline void SwapIfRequiredInt32(int32_t *x) const { + if (swap_) { + Swap32(x); + } + } + + inline void SwapIfRequiredFloat(float *x) const { + if (swap_) { + Swap32(reinterpret_cast(x)); + } + } + + inline void SwapIfRequiredInt64(int64_t *x) const { + if (swap_) { + Swap64(x); + } + } + + inline void SwapIfRequiredDouble(double *x) const { + if (swap_) { + Swap64(reinterpret_cast(x)); + } + } + + inline void SwapIfRequiredInt16Arr(int16_t *x, 
int32_t length) const { + if (swap_) { + for (int32_t i = 0; i < length; i++) Swap16(&x[i]); + } + } + + inline void SwapIfRequiredUnsignedInt16Arr(uint16_t *x, + int32_t length) const { + if (swap_) { + for (int32_t i = 0; i < length; i++) + Swap16(reinterpret_cast(&x[i])); + } + } + + inline void SwapIfRequiredInt32Arr(int32_t *x, int32_t length) const { + if (swap_) { + for (int32_t i = 0; i < length; i++) Swap32(&x[i]); + } + } + + inline void SwapIfRequiredFloatArr(float *x, int32_t length) const { + if (swap_) { + for (int32_t i = 0; i < length; i++) + Swap32(reinterpret_cast(&x[i])); + } + } + + inline void SwapIfRequiredInt64Arr(int64_t *x, int32_t length) const { + if (swap_) { + for (int32_t i = 0; i < length; i++) Swap64(&x[i]); + } + } + + inline void SwapIfRequiredDoubleArr(double *x, int32_t length) const { + if (swap_) { + for (int32_t i = 0; i < length; i++) + Swap64(reinterpret_cast(&x[i])); + } + } + + private: + inline void Swap16(int16_t *x) const { + *x = ((*x & 0xFF) << 8) | ((*x >> 8) & 0xFF); + } + + inline void Swap32(int32_t *x) const { + *x = ((*x & 0xFF) << 24) | (((*x >> 8) & 0xFF) << 16) | + (((*x >> 16) & 0xFF) << 8) | ((*x >> 24) & 0xFF); + } + + inline void Swap64(int64_t *x) const { + *x = ((*x & 0xFF) << 56) | (((*x >> 8) & 0xFF) << 48) | + (((*x >> 16) & 0xFF) << 40) | (((*x >> 24) & 0xFF) << 32) | + (((*x >> 32) & 0xFF) << 24) | (((*x >> 40) & 0xFF) << 16) | + (((*x >> 48) & 0xFF) << 8) | ((*x >> 56) & 0xFF); + } + + bool swap_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_client.h b/tensorflow/contrib/ignite/kernels/ignite_client.h new file mode 100644 index 0000000000000000000000000000000000000000..459b50b48fd95ad105bccaca4076160e0ef152ee --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_client.h @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ + +#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class Client { + public: + Client(bool big_endian) : byte_swapper_(ByteSwapper(big_endian)) {} + virtual Status Connect() = 0; + virtual Status Disconnect() = 0; + virtual bool IsConnected() = 0; + virtual int GetSocketDescriptor() = 0; + virtual Status ReadData(uint8_t *buf, const int32_t length) = 0; + virtual Status WriteData(const uint8_t *buf, const int32_t length) = 0; + + inline Status ReadByte(uint8_t *data) { return ReadData(data, 1); } + + inline Status ReadShort(int16_t *data) { + TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 2)); + byte_swapper_.SwapIfRequiredInt16(data); + + return Status::OK(); + } + + inline Status ReadInt(int32_t *data) { + TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 4)); + byte_swapper_.SwapIfRequiredInt32(data); + + return Status::OK(); + } + + inline Status ReadLong(int64_t *data) { + TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 8)); + byte_swapper_.SwapIfRequiredInt64(data); + + return Status::OK(); + } + + inline Status WriteByte(const uint8_t data) { return WriteData(&data, 1); } + + inline Status WriteShort(const 
int16_t data) { + int16_t tmp = data; + byte_swapper_.SwapIfRequiredInt16(&tmp); + return WriteData((uint8_t *)&tmp, 2); + } + + inline Status WriteInt(const int32_t data) { + int32_t tmp = data; + byte_swapper_.SwapIfRequiredInt32(&tmp); + return WriteData((uint8_t *)&tmp, 4); + } + + inline Status WriteLong(const int64_t data) { + int64_t tmp = data; + byte_swapper_.SwapIfRequiredInt64(&tmp); + return WriteData((uint8_t *)&tmp, 8); + } + + private: + const ByteSwapper byte_swapper_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset.cc new file mode 100644 index 0000000000000000000000000000000000000000..c4a7d3c513a796c9d95b371bedc609fd75188817 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_dataset.cc @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +IgniteDataset::IgniteDataset(OpKernelContext* ctx, string cache_name, + string host, int32 port, bool local, int32 part, + int32 page_size, string username, string password, + string certfile, string keyfile, + string cert_password, std::vector schema, + std::vector permutation, + DataTypeVector dtypes, + std::vector shapes) + : DatasetBase(DatasetContext(ctx)), + cache_name_(std::move(cache_name)), + host_(std::move(host)), + port_(port), + local_(local), + part_(part), + page_size_(page_size), + username_(std::move(username)), + password_(std::move(password)), + certfile_(std::move(certfile)), + keyfile_(std::move(keyfile)), + cert_password_(std::move(cert_password)), + schema_(std::move(schema)), + permutation_(std::move(permutation)), + dtypes_(dtypes), + shapes_(shapes) { + LOG(INFO) << "Ignite Dataset created [cache_name='" << cache_name_ + << "', host='" << host_ << "', port=" << port_ + << ", local=" << local_ << ", part=" << part_ + << ", page_size=" << page_size_ << ", username='" << username_ + << "', certfile='" << certfile_ << "', keyfile='" + << keyfile_ + "']"; +} + +IgniteDataset::~IgniteDataset() { LOG(INFO) << "Ignite Dataset destroyed"; } + +std::unique_ptr IgniteDataset::MakeIteratorInternal( + const string& prefix) const { + return std::unique_ptr(new IgniteDatasetIterator( + {this, strings::StrCat(prefix, "::Ignite")}, std::move(this->host_), + this->port_, std::move(this->cache_name_), this->local_, this->part_, + this->page_size_, std::move(this->username_), std::move(this->password_), + std::move(this->certfile_), std::move(this->keyfile_), + std::move(this->cert_password_), std::move(this->schema_), + std::move(this->permutation_))); +} + +const DataTypeVector& IgniteDataset::output_dtypes() const { return 
dtypes_; } + +const std::vector& IgniteDataset::output_shapes() const { + return shapes_; +} + +string IgniteDataset::DebugString() const { return "IgniteDatasetOp::Dataset"; } + +Status IgniteDataset::AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const { + return errors::Unimplemented( + "IgniteDataset does not support 'AsGraphDefInternal'"); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.h b/tensorflow/contrib/ignite/kernels/ignite_dataset.h new file mode 100644 index 0000000000000000000000000000000000000000..66bfdf2e2a168e59cd2fec8e2ac5b8fd482d5c15 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_dataset.h @@ -0,0 +1,63 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { + +class IgniteDataset : public DatasetBase { + public: + IgniteDataset(OpKernelContext* ctx, string cache_name, string host, + int32 port, bool local, int32 part, int32 page_size, + string username, string password, string certfile, + string keyfile, string cert_password, std::vector schema, + std::vector permutation, DataTypeVector dtypes, + std::vector shapes); + ~IgniteDataset(); + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override; + const DataTypeVector& output_dtypes() const override; + const std::vector& output_shapes() const override; + string DebugString() const override; + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override; + + private: + const string cache_name_; + const string host_; + const int32 port_; + const bool local_; + const int32 part_; + const int32 page_size_; + const string username_; + const string password_; + const string certfile_; + const string keyfile_; + const string cert_password_; + const std::vector schema_; + const std::vector permutation_; + const DataTypeVector dtypes_; + const std::vector shapes_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc new file mode 100644 index 0000000000000000000000000000000000000000..5da9127aa6a3a4bc16347e6890cc1ba44406c0d5 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc @@ -0,0 +1,422 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h" + +#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" +#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +IgniteDatasetIterator::IgniteDatasetIterator( + const Params& params, string host, int32 port, string cache_name, + bool local, int32 part, int32 page_size, string username, string password, + string certfile, string keyfile, string cert_password, + std::vector schema, std::vector permutation) + : DatasetIterator(params), + cache_name_(std::move(cache_name)), + local_(local), + part_(part), + page_size_(page_size), + username_(std::move(username)), + password_(std::move(password)), + schema_(std::move(schema)), + permutation_(std::move(permutation)), + remainder_(-1), + cursor_id_(-1), + last_page_(false), + valid_state_(true) { + Client* p_client = new PlainClient(std::move(host), port, false); + + if (certfile.empty()) + client_ = std::unique_ptr(p_client); + else + client_ = std::unique_ptr( + new SslWrapper(std::unique_ptr(p_client), std::move(certfile), + std::move(keyfile), std::move(cert_password), false)); + + LOG(INFO) << "Ignite Dataset Iterator created"; +} + +IgniteDatasetIterator::~IgniteDatasetIterator() { + Status 
status = CloseConnection(); + if (!status.ok()) LOG(ERROR) << status.ToString(); + + LOG(INFO) << "Ignite Dataset Iterator destroyed"; +} + +Status IgniteDatasetIterator::GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) { + mutex_lock l(mutex_); + + if (valid_state_) { + Status status = + GetNextInternalWithValidState(ctx, out_tensors, end_of_sequence); + + if (!status.ok()) valid_state_ = false; + + return status; + } + + return errors::Unknown("Iterator is invalid"); +} + +Status IgniteDatasetIterator::SaveInternal(IteratorStateWriter* writer) { + return errors::Unimplemented( + "Iterator for IgniteDataset does not support 'SaveInternal'"); +} + +Status IgniteDatasetIterator::RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) { + return errors::Unimplemented( + "Iterator for IgniteDataset does not support 'RestoreInternal')"); +} + +Status IgniteDatasetIterator::GetNextInternalWithValidState( + IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) { + if (remainder_ == 0 && last_page_) { + cursor_id_ = -1; + *end_of_sequence = true; + + return Status::OK(); + } else { + TF_RETURN_IF_ERROR(EstablishConnection()); + + if (remainder_ == -1) { + TF_RETURN_IF_ERROR(ScanQuery()); + } else if (remainder_ == 0) { + TF_RETURN_IF_ERROR(LoadNextPage()); + } + + uint8_t* initial_ptr = ptr_; + std::vector tensors; + std::vector types; + + TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types)); // Parse key + TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types)); // Parse val + + remainder_ -= (ptr_ - initial_ptr); + + TF_RETURN_IF_ERROR(CheckTypes(types)); + + for (size_t i = 0; i < tensors.size(); i++) + out_tensors->push_back(tensors[permutation_[i]]); + + *end_of_sequence = false; + + return Status::OK(); + } + + *end_of_sequence = true; + + return Status::OK(); +} + +Status IgniteDatasetIterator::EstablishConnection() { + if (!client_->IsConnected()) { + 
TF_RETURN_IF_ERROR(client_->Connect()); + + Status status = Handshake(); + if (!status.ok()) { + Status disconnect_status = client_->Disconnect(); + if (!disconnect_status.ok()) LOG(ERROR) << disconnect_status.ToString(); + + return status; + } + } + + return Status::OK(); +} + +Status IgniteDatasetIterator::CloseConnection() { + if (cursor_id_ != -1 && !last_page_) { + TF_RETURN_IF_ERROR(EstablishConnection()); + + TF_RETURN_IF_ERROR(client_->WriteInt(kCloseConnectionReqLength)); + TF_RETURN_IF_ERROR(client_->WriteShort(kCloseConnectionOpcode)); + TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID + TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_)); // Resource ID + + int32_t res_len; + TF_RETURN_IF_ERROR(client_->ReadInt(&res_len)); + if (res_len < kMinResLength) + return errors::Unknown("Close Resource Response is corrupted"); + + int64_t req_id; + TF_RETURN_IF_ERROR(client_->ReadLong(&req_id)); + int32_t status; + TF_RETURN_IF_ERROR(client_->ReadInt(&status)); + if (status != 0) { + uint8_t err_msg_header; + TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header)); + if (err_msg_header == kStringVal) { + int32_t err_msg_length; + TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length)); + + uint8_t* err_msg_c = new uint8_t[err_msg_length]; + auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; }); + TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length)); + string err_msg(reinterpret_cast(err_msg_c), err_msg_length); + + return errors::Unknown("Close Resource Error [status=", status, + ", message=", err_msg, "]"); + } + return errors::Unknown("Close Resource Error [status=", status, "]"); + } + + cursor_id_ = -1; + + return client_->Disconnect(); + } else { + LOG(INFO) << "Query Cursor " << cursor_id_ << " is already closed"; + } + + return client_->IsConnected() ? 
client_->Disconnect() : Status::OK(); +} + +Status IgniteDatasetIterator::Handshake() { + int32_t msg_len = kHandshakeReqDefaultLength; + + if (username_.empty()) + msg_len += 1; + else + msg_len += 5 + username_.length(); // 1 byte header, 4 bytes length. + + if (password_.empty()) + msg_len += 1; + else + msg_len += 5 + password_.length(); // 1 byte header, 4 bytes length. + + TF_RETURN_IF_ERROR(client_->WriteInt(msg_len)); + TF_RETURN_IF_ERROR(client_->WriteByte(1)); + TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMajorVersion)); + TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMinorVersion)); + TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolPatchVersion)); + TF_RETURN_IF_ERROR(client_->WriteByte(2)); + if (username_.empty()) { + TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal)); + } else { + TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal)); + TF_RETURN_IF_ERROR(client_->WriteInt(username_.length())); + TF_RETURN_IF_ERROR( + client_->WriteData(reinterpret_cast(username_.c_str()), + username_.length())); + } + + if (password_.empty()) { + TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal)); + } else { + TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal)); + TF_RETURN_IF_ERROR(client_->WriteInt(password_.length())); + TF_RETURN_IF_ERROR( + client_->WriteData(reinterpret_cast(password_.c_str()), + password_.length())); + } + + int32_t handshake_res_len; + TF_RETURN_IF_ERROR(client_->ReadInt(&handshake_res_len)); + uint8_t handshake_res; + TF_RETURN_IF_ERROR(client_->ReadByte(&handshake_res)); + + if (handshake_res != 1) { + int16_t serv_ver_major; + TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_major)); + int16_t serv_ver_minor; + TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_minor)); + int16_t serv_ver_patch; + TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_patch)); + uint8_t header; + TF_RETURN_IF_ERROR(client_->ReadByte(&header)); + + if (header == kStringVal) { + int32_t length; + TF_RETURN_IF_ERROR(client_->ReadInt(&length)); + + uint8_t* err_msg_c = new 
uint8_t[length]; + auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; }); + TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, length)); + string err_msg(reinterpret_cast(err_msg_c), length); + + return errors::Unknown("Handshake Error [result=", handshake_res, + ", version=", serv_ver_major, ".", serv_ver_minor, + ".", serv_ver_patch, ", message='", err_msg, "']"); + } else if (header == kNullVal) { + return errors::Unknown("Handshake Error [result=", handshake_res, + ", version=", serv_ver_major, ".", serv_ver_minor, + ".", serv_ver_patch, "]"); + } else { + return errors::Unknown("Handshake Error [result=", handshake_res, + ", version=", serv_ver_major, ".", serv_ver_minor, + ".", serv_ver_patch, "]"); + } + } + + return Status::OK(); +} + +Status IgniteDatasetIterator::ScanQuery() { + TF_RETURN_IF_ERROR(client_->WriteInt(kScanQueryReqLength)); + TF_RETURN_IF_ERROR(client_->WriteShort(kScanQueryOpcode)); + TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID + TF_RETURN_IF_ERROR( + client_->WriteInt(JavaHashCode(cache_name_))); // Cache name + TF_RETURN_IF_ERROR(client_->WriteByte(0)); // Flags + TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal)); // Filter object + TF_RETURN_IF_ERROR(client_->WriteInt(page_size_)); // Cursor page size + TF_RETURN_IF_ERROR(client_->WriteInt(part_)); // part_ition to query + TF_RETURN_IF_ERROR(client_->WriteByte(local_)); // local_ flag + + uint64 wait_start = Env::Default()->NowMicros(); + int32_t res_len; + TF_RETURN_IF_ERROR(client_->ReadInt(&res_len)); + int64_t wait_stop = Env::Default()->NowMicros(); + + LOG(INFO) << "Scan Query waited " << (wait_stop - wait_start) / 1000 << " ms"; + + if (res_len < kMinResLength) + return errors::Unknown("Scan Query Response is corrupted"); + + int64_t req_id; + TF_RETURN_IF_ERROR(client_->ReadLong(&req_id)); + + int32_t status; + TF_RETURN_IF_ERROR(client_->ReadInt(&status)); + + if (status != 0) { + uint8_t err_msg_header; + 
TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header)); + + if (err_msg_header == kStringVal) { + int32_t err_msg_length; + TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length)); + + uint8_t* err_msg_c = new uint8_t[err_msg_length]; + auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; }); + TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length)); + string err_msg(reinterpret_cast(err_msg_c), err_msg_length); + + return errors::Unknown("Scan Query Error [status=", status, + ", message=", err_msg, "]"); + } + return errors::Unknown("Scan Query Error [status=", status, "]"); + } + + TF_RETURN_IF_ERROR(client_->ReadLong(&cursor_id_)); + + int32_t row_cnt; + TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt)); + + int32_t page_size = res_len - kScanQueryResHeaderLength; + + return ReceivePage(page_size); +} + +Status IgniteDatasetIterator::LoadNextPage() { + TF_RETURN_IF_ERROR(client_->WriteInt(kLoadNextPageReqLength)); + TF_RETURN_IF_ERROR(client_->WriteShort(kLoadNextPageOpcode)); + TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID + TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_)); // Cursor ID + + uint64 wait_start = Env::Default()->NowMicros(); + int32_t res_len; + TF_RETURN_IF_ERROR(client_->ReadInt(&res_len)); + uint64 wait_stop = Env::Default()->NowMicros(); + + LOG(INFO) << "Load Next Page waited " << (wait_stop - wait_start) / 1000 + << " ms"; + + if (res_len < kMinResLength) + return errors::Unknown("Load Next Page Response is corrupted"); + + int64_t req_id; + TF_RETURN_IF_ERROR(client_->ReadLong(&req_id)); + + int32_t status; + TF_RETURN_IF_ERROR(client_->ReadInt(&status)); + + if (status != 0) { + uint8_t err_msg_header; + TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header)); + + if (err_msg_header == kStringVal) { + int32_t err_msg_length; + TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length)); + + uint8_t* err_msg_c = new uint8_t[err_msg_length]; + auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; }); + 
TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length)); + string err_msg(reinterpret_cast(err_msg_c), err_msg_length); + + return errors::Unknown("Load Next Page Error [status=", status, + ", message=", err_msg, "]"); + } + return errors::Unknown("Load Next Page Error [status=", status, "]"); + } + + int32_t row_cnt; + TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt)); + + int32_t page_size = res_len - kLoadNextPageResHeaderLength; + + return ReceivePage(page_size); +} + +Status IgniteDatasetIterator::ReceivePage(int32_t page_size) { + remainder_ = page_size; + page_ = std::unique_ptr(new uint8_t[remainder_]); + ptr_ = page_.get(); + + uint64 start = Env::Default()->NowMicros(); + TF_RETURN_IF_ERROR(client_->ReadData(ptr_, remainder_)); + uint64 stop = Env::Default()->NowMicros(); + + double size_in_mb = 1.0 * remainder_ / 1024 / 1024; + double time_in_s = 1.0 * (stop - start) / 1000 / 1000; + LOG(INFO) << "Page size " << size_in_mb << " Mb, time " << time_in_s * 1000 + << " ms download speed " << size_in_mb / time_in_s << " Mb/sec"; + + uint8_t last_page_b; + TF_RETURN_IF_ERROR(client_->ReadByte(&last_page_b)); + + last_page_ = !last_page_b; + + return Status::OK(); +} + +Status IgniteDatasetIterator::CheckTypes(const std::vector& types) { + if (schema_.size() != types.size()) + return errors::Unknown("Object has unexpected schema"); + + for (size_t i = 0; i < schema_.size(); i++) { + if (schema_[i] != types[permutation_[i]]) + return errors::Unknown("Object has unexpected schema"); + } + + return Status::OK(); +} + +int32_t IgniteDatasetIterator::JavaHashCode(string str) const { + int32_t h = 0; + for (char& c : str) { + h = 31 * h + c; + } + return h; +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..c499e2c9ccfac5c15db08c8fd8b26c37aa0404f3 --- /dev/null +++ 
b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h
@@ -0,0 +1,99 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_
+
+#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h"
+#include "tensorflow/contrib/ignite/kernels/ignite_client.h"
+#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// Iterator that streams key/value objects of one Ignite cache over the
+// client socket and converts every parsed object into output tensors.
+// Saving and restoring iterator state is not supported.
+class IgniteDatasetIterator : public DatasetIterator {
+ public:
+  // Builds a plain socket client for host:port; when `certfile` is non-empty
+  // the client is wrapped into an SSL client that uses `certfile`, `keyfile`
+  // and `cert_password`. No connection is opened by the constructor.
+  IgniteDatasetIterator(const Params& params, string host, int32 port,
+                        string cache_name, bool local, int32 part,
+                        int32 page_size, string username, string password,
+                        string certfile, string keyfile, string cert_password,
+                        std::vector schema,
+                        std::vector permutation);
+  // Closes the connection; a failure is logged rather than propagated.
+  ~IgniteDatasetIterator();
+  // Produces the next element under `mutex_`. The first failure marks the
+  // iterator invalid and every later call returns an "Iterator is invalid"
+  // error.
+  Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors,
+                         bool* end_of_sequence) override;
+
+ protected:
+  // Both always return an Unimplemented error.
+  Status SaveInternal(IteratorStateWriter* writer) override;
+  Status RestoreInternal(IteratorContext* ctx,
+                         IteratorStateReader* reader) override;
+
+ private:
+  // Connects if needed, issues the scan query or loads the next page when the
+  // current page is exhausted, then parses one key and one value.
+  Status GetNextInternalWithValidState(IteratorContext* ctx,
+                                       std::vector* out_tensors,
+                                       bool* end_of_sequence);
+
+  // Connects and performs the handshake when the client is not connected;
+  // disconnects again if the handshake fails.
+  Status EstablishConnection();
+  // Sends a close request for the cursor when one is still open, then
+  // disconnects the client.
+  Status CloseConnection();
+  // Sends the protocol version and optional username/password and reads the
+  // server verdict.
+  Status Handshake();
+  // Sends the scan query, stores the returned cursor id and receives the
+  // first page.
+  Status ScanQuery();
+  // Requests and receives the next page of the open cursor.
+  Status LoadNextPage();
+  // Reads `page_size` bytes into `page_` followed by one flag byte;
+  // `last_page_` becomes true when that byte is zero.
+  Status ReceivePage(int32_t page_size);
+  // Verifies that the parsed types match `schema_` through `permutation_`.
+  Status CheckTypes(const std::vector& types);
+  // Computes h = 31 * h + c over the characters of `str`.
+  int32_t JavaHashCode(string str) const;
+
+  std::unique_ptr client_;
+  BinaryObjectParser parser_;
+
+  const string cache_name_;
+  const bool local_;
+  const int32 part_;
+  const int32 page_size_;
+  const string username_;
+  const string password_;
+  const std::vector schema_;
+  const std::vector permutation_;
+
+  // Bytes of the current page that are not parsed yet; -1 until the scan
+  // query has been issued.
+  int32_t remainder_;
+  // Server-side cursor id; -1 while no cursor is open.
+  int64_t cursor_id_;
+  bool last_page_;
+
+  // Cleared by GetNextInternal after the first failed call.
+  bool valid_state_;
+
+  mutex mutex_;
+
+  // Buffer holding the current page and the read position inside it.
+  std::unique_ptr page_;
+  uint8_t* ptr_;
+};
+
+// Header bytes written/read in front of optional string values.
+constexpr uint8_t kNullVal = 101;
+constexpr uint8_t kStringVal = 9;
+// Protocol version sent by Handshake().
+constexpr uint8_t kProtocolMajorVersion = 1;
+constexpr uint8_t kProtocolMinorVersion = 1;
+constexpr uint8_t kProtocolPatchVersion = 0;
+// Operation codes written by ScanQuery(), LoadNextPage(), CloseConnection().
+constexpr int16_t kScanQueryOpcode = 2000;
+constexpr int16_t kLoadNextPageOpcode = 2001;
+constexpr int16_t kCloseConnectionOpcode = 0;
+// Request lengths and response header lengths, in bytes.
+constexpr int32_t kScanQueryReqLength = 25;
+constexpr int32_t kScanQueryResHeaderLength = 25;
+constexpr int32_t kLoadNextPageReqLength = 18;
+constexpr int32_t kLoadNextPageResHeaderLength = 17;
+constexpr int32_t kCloseConnectionReqLength = 18;
+// Handshake length before the username/password fields are added.
+constexpr int32_t kHandshakeReqDefaultLength = 8;
+// Responses shorter than this are rejected as corrupted.
+constexpr int32_t kMinResLength = 12;
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f75b1c5ff55ca9ee493148ff79c2edd4b15ac42a
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc
@@ -0,0 +1,198 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h" +#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/lib/strings/numbers.h" + +namespace tensorflow { +namespace { + +Status SchemaToTypes(const std::vector& schema, DataTypeVector* dtypes) { + for (auto e : schema) { + if (e == BYTE || e == BYTE_ARR) { + dtypes->push_back(DT_UINT8); + } else if (e == SHORT || e == SHORT_ARR) { + dtypes->push_back(DT_INT16); + } else if (e == INT || e == INT_ARR) { + dtypes->push_back(DT_INT32); + } else if (e == LONG || e == LONG_ARR) { + dtypes->push_back(DT_INT64); + } else if (e == FLOAT || e == FLOAT_ARR) { + dtypes->push_back(DT_FLOAT); + } else if (e == DOUBLE || e == DOUBLE_ARR) { + dtypes->push_back(DT_DOUBLE); + } else if (e == USHORT || e == USHORT_ARR) { + dtypes->push_back(DT_UINT8); + } else if (e == BOOL || e == BOOL_ARR) { + dtypes->push_back(DT_BOOL); + } else if (e == STRING || e == STRING_ARR) { + dtypes->push_back(DT_STRING); + } else { + return errors::Unknown("Unexpected type in schema [type_id=", e, "]"); + } + } + + return Status::OK(); +} + +Status SchemaToShapes(const std::vector& schema, + std::vector* shapes) { + for (auto e : schema) { + if (e >= 1 && e < 10) { + shapes->push_back(PartialTensorShape({})); + } else if (e >= 12 && e < 21) { + shapes->push_back(PartialTensorShape({-1})); + } else { + return errors::Unknown("Unexpected type in schema [type_id=", 
e, "]"); + } + } + + return Status::OK(); +} + +class IgniteDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string cache_name = ""; + string host = ""; + int32 port = -1; + bool local = false; + int32 part = -1; + int32 page_size = -1; + string username = ""; + string password = ""; + string certfile = ""; + string keyfile = ""; + string cert_password = ""; + + const char* env_cache_name = std::getenv("IGNITE_DATASET_CACHE_NAME"); + const char* env_host = std::getenv("IGNITE_DATASET_HOST"); + const char* env_port = std::getenv("IGNITE_DATASET_PORT"); + const char* env_local = std::getenv("IGNITE_DATASET_LOCAL"); + const char* env_part = std::getenv("IGNITE_DATASET_PART"); + const char* env_page_size = std::getenv("IGNITE_DATASET_PAGE_SIZE"); + const char* env_username = std::getenv("IGNITE_DATASET_USERNAME"); + const char* env_password = std::getenv("IGNITE_DATASET_PASSWORD"); + const char* env_certfile = std::getenv("IGNITE_DATASET_CERTFILE"); + const char* env_keyfile = std::getenv("IGNITE_DATASET_KEYFILE"); + const char* env_cert_password = std::getenv("IGNITE_DATASET_CERT_PASSWORD"); + + if (env_cache_name) { + cache_name = string(env_cache_name); + } else { + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "cache_name", &cache_name)); + } + + if (env_host) { + host = string(env_host); + } else { + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "host", &host)); + } + + if (env_port) { + OP_REQUIRES(ctx, strings::safe_strto32(env_port, &port), + errors::InvalidArgument("IGNITE_DATASET_PORT environment " + "variable is not a valid integer: ", + env_port)); + } else { + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "port", &port)); + } + + if (env_local) { + local = true; + } else { + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "local", &local)); + } + + if (env_part) { + OP_REQUIRES(ctx, strings::safe_strto32(env_part, &part), + 
errors::InvalidArgument("IGNITE_DATASET_PART environment " + "variable is not a valid integer: ", + env_part)); + } else { + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "part", &part)); + } + + if (env_page_size) { + OP_REQUIRES(ctx, strings::safe_strto32(env_page_size, &page_size), + errors::InvalidArgument("IGNITE_DATASET_PAGE_SIZE " + "environment variable is not a valid " + "integer: ", + env_page_size)); + } else { + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "page_size", &page_size)); + } + + if (env_username) username = string(env_username); + + if (env_password) password = string(env_password); + + if (env_certfile) certfile = string(env_certfile); + + if (env_keyfile) keyfile = string(env_keyfile); + + if (env_cert_password) cert_password = string(env_cert_password); + + const Tensor* schema_tensor; + OP_REQUIRES_OK(ctx, ctx->input("schema", &schema_tensor)); + OP_REQUIRES(ctx, schema_tensor->dims() == 1, + errors::InvalidArgument("`schema` must be a vector.")); + + std::vector schema; + schema.reserve(schema_tensor->NumElements()); + for (int i = 0; i < schema_tensor->NumElements(); i++) { + schema.push_back(schema_tensor->flat()(i)); + } + + const Tensor* permutation_tensor; + OP_REQUIRES_OK(ctx, ctx->input("permutation", &permutation_tensor)); + OP_REQUIRES(ctx, permutation_tensor->dims() == 1, + errors::InvalidArgument("`permutation` must be a vector.")); + + std::vector permutation; + permutation.resize(permutation_tensor->NumElements()); + for (int i = 0; i < permutation_tensor->NumElements(); i++) { + // Inversed permutation. 
+ permutation[permutation_tensor->flat()(i)] = i; + } + + DataTypeVector dtypes; + std::vector shapes; + + OP_REQUIRES_OK(ctx, SchemaToTypes(schema, &dtypes)); + OP_REQUIRES_OK(ctx, SchemaToShapes(schema, &shapes)); + + *output = new IgniteDataset( + ctx, std::move(cache_name), std::move(host), port, local, part, + page_size, std::move(username), std::move(password), + std::move(certfile), std::move(keyfile), std::move(cert_password), + std::move(schema), std::move(permutation), std::move(dtypes), + std::move(shapes)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("IgniteDataset").Device(DEVICE_CPU), + IgniteDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client.h b/tensorflow/contrib/ignite/kernels/ignite_plain_client.h new file mode 100644 index 0000000000000000000000000000000000000000..75424c19ee4b7df5378aa23cb41db1752e8d0651 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_
+
+#include "tensorflow/contrib/ignite/kernels/ignite_client.h"
+
+namespace tensorflow {
+
+// Client implementation that talks to a host:port endpoint over an
+// unencrypted socket. Platform-specific definitions live in the
+// ignite_plain_client_unix.cc / ignite_plain_client_windows.cc files.
+class PlainClient : public Client {
+ public:
+  // Stores the endpoint; `big_endian` is forwarded to the Client base class.
+  PlainClient(string host, int port, bool big_endian);
+  // Disconnects when still connected; a failure is logged as a warning.
+  ~PlainClient();
+
+  // Opens the socket connection to host_:port_.
+  Status Connect() override;
+  // Closes the socket.
+  Status Disconnect() override;
+  // Reports whether a socket is currently held.
+  bool IsConnected() override;
+  // Returns the raw socket descriptor.
+  int GetSocketDescriptor() override;
+  // Reads exactly `length` bytes into `buf`, looping over partial reads.
+  Status ReadData(uint8_t* buf, const int32_t length) override;
+  // Writes exactly `length` bytes from `buf`, looping over partial writes.
+  Status WriteData(const uint8_t* buf, const int32_t length) override;
+
+ private:
+  const string host_;
+  const int port_;
+  // Socket descriptor; holds the "no socket" sentinel while disconnected.
+  int sock_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc b/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc
new file mode 100644
index 0000000000000000000000000000000000000000..cf672942c61e1239332711db12e62088737c4f41
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc
@@ -0,0 +1,123 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" + +#include +#include +#include +#include + +#include +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +PlainClient::PlainClient(string host, int port, bool big_endian) + : Client(big_endian), host_(std::move(host)), port_(port), sock_(-1) {} + +PlainClient::~PlainClient() { + if (IsConnected()) { + Status status = Disconnect(); + if (!status.ok()) LOG(WARNING) << status.ToString(); + } +} + +Status PlainClient::Connect() { + if (sock_ == -1) { + sock_ = socket(AF_INET, SOCK_STREAM, 0); + if (sock_ == -1) return errors::Internal("Failed to create socket"); + } + + sockaddr_in server; + + server.sin_addr.s_addr = inet_addr(host_.c_str()); + if (server.sin_addr.s_addr == -1) { + hostent* he; + in_addr** addr_list; + + if ((he = gethostbyname(host_.c_str())) == NULL) + return errors::Internal("Failed to resolve hostname \"", host_, "\""); + + addr_list = (in_addr**)he->h_addr_list; + if (addr_list[0] != NULL) server.sin_addr = *addr_list[0]; + } + + server.sin_family = AF_INET; + server.sin_port = htons(port_); + + if (connect(sock_, (sockaddr*)&server, sizeof(server)) < 0) + return errors::Internal("Failed to connect to \"", host_, ":", port_, "\""); + + LOG(INFO) << "Connection to \"" << host_ << ":" << port_ << "\" established"; + + return Status::OK(); +} + +Status PlainClient::Disconnect() { + int close_res = close(sock_); + sock_ = -1; + + LOG(INFO) << "Connection to \"" << host_ << ":" << port_ << "\" is closed"; + + return close_res == 0 + ? 
Status::OK() + : errors::Internal("Failed to correctly close connection"); +} + +bool PlainClient::IsConnected() { return sock_ != -1; } + +int PlainClient::GetSocketDescriptor() { return sock_; } + +Status PlainClient::ReadData(uint8_t* buf, const int32_t length) { + int received = 0; + + while (received < length) { + int res = recv(sock_, buf, length - received, 0); + + if (res < 0) + return errors::Internal("Error occurred while reading from socket: ", res, + ", ", string(strerror(errno))); + + if (res == 0) return errors::Internal("Server closed connection"); + + received += res; + buf += res; + } + + return Status::OK(); +} + +Status PlainClient::WriteData(const uint8_t* buf, const int32_t length) { + int sent = 0; + + while (sent < length) { + int res = send(sock_, buf, length - sent, 0); + + if (res < 0) + return errors::Internal("Error occurred while writing into socket: ", res, + ", ", string(strerror(errno))); + + sent += res; + buf += res; + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc b/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc new file mode 100644 index 0000000000000000000000000000000000000000..dad5aace5fabe1df58bb9579bf578f4c35324315 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc @@ -0,0 +1,142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" + +#define WIN32_LEAN_AND_MEAN +#include +#include +#include + +#pragma comment(lib, "Ws2_32.lib") +#pragma comment(lib, "Mswsock.lib") +#pragma comment(lib, "AdvApi32.lib") + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +PlainClient::PlainClient(string host, int port, bool big_endian) + : Client(big_endian), + host_(std::move(host)), + port_(port), + sock_(INVALID_SOCKET) {} + +PlainClient::~PlainClient() { + if (IsConnected()) { + Status status = Disconnect(); + if (!status.ok()) LOG(WARNING) << status.ToString(); + } +} + +Status PlainClient::Connect() { + WSADATA wsaData; + addrinfo *result = NULL, *ptr = NULL, hints; + + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) return errors::Internal("WSAStartup failed with error: ", res); + + ZeroMemory(&hints, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + + res = getaddrinfo(host_.c_str(), std::to_string(port_).c_str(), &hints, + &result); + if (res != 0) return errors::Internal("Getaddrinfo failed with error: ", res); + + auto clean = gtl::MakeCleanup([result] { freeaddrinfo(result); }); + + for (ptr = result; ptr != NULL; ptr = ptr->ai_next) { + sock_ = socket(ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol); + if (sock_ == INVALID_SOCKET) { + WSACleanup(); + return errors::Internal("Socket failed with error: ", WSAGetLastError()); + } + + res = connect(sock_, ptr->ai_addr, (int)ptr->ai_addrlen); + if (res == SOCKET_ERROR) { + closesocket(sock_); + sock_ = INVALID_SOCKET; + continue; + } + + break; + } + + if (sock_ == INVALID_SOCKET) { + WSACleanup(); + return errors::Internal("Unable to connect to server"); + } + + LOG(INFO) << "Connection to \"" << 
host_ << ":" << port_ << "\" established"; + + return Status::OK(); +} + +Status PlainClient::Disconnect() { + int res = shutdown(sock_, SD_SEND); + closesocket(sock_); + WSACleanup(); + + if (res == SOCKET_ERROR) + return errors::Internal("Shutdown failed with error: ", WSAGetLastError()); + else + return Status::OK(); +} + +bool PlainClient::IsConnected() { return sock_ != INVALID_SOCKET; } + +int PlainClient::GetSocketDescriptor() { return sock_; } + +Status PlainClient::ReadData(uint8_t *buf, const int32_t length) { + int received = 0; + + while (received < length) { + int res = recv(sock_, (char *)buf, length - received, 0); + + if (res < 0) + return errors::Internal("Error occurred while reading from socket: ", + res); + + if (res == 0) return errors::Internal("Server closed connection"); + + received += res; + buf += res; + } + + return Status::OK(); +} + +Status PlainClient::WriteData(const uint8_t *buf, const int32_t length) { + int sent = 0; + + while (sent < length) { + int res = send(sock_, (char *)buf, length - sent, 0); + + if (res < 0) + return errors::Internal("Error occurred while writing into socket: ", + res); + + sent += res; + buf += res; + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc new file mode 100644 index 0000000000000000000000000000000000000000..ceb479b0846574a35d86002ebb9c3e8e1d3687ac --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h" + +#include +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +static int PasswordCb(char *buf, int size, int rwflag, void *password) { + strncpy(buf, (char *)(password), size); + buf[size - 1] = '\0'; + return (strlen(buf)); +} + +SslWrapper::SslWrapper(std::shared_ptr client, string certfile, + string keyfile, string cert_password, bool big_endian) + : Client(big_endian), + client_(client), + certfile_(std::move(certfile)), + keyfile_(std::move(keyfile)), + cert_password_(std::move(cert_password)), + ctx_(nullptr), + ssl_(nullptr) {} + +SslWrapper::~SslWrapper() { + if (IsConnected()) { + Status status = Disconnect(); + if (!status.ok()) LOG(WARNING) << status.ToString(); + } + + if (ctx_ != nullptr) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + + if (ssl_ != nullptr) { + SSL_free(ssl_); + ssl_ = nullptr; + } +} + +Status SslWrapper::InitSslContext() { + OpenSSL_add_all_algorithms(); + SSL_load_error_strings(); + + ctx_ = SSL_CTX_new(SSLv23_method()); + if (ctx_ == NULL) return errors::Internal("Couldn't create SSL context"); + + SSL_CTX_set_default_passwd_cb(ctx_, PasswordCb); + SSL_CTX_set_default_passwd_cb_userdata(ctx_, (void *)cert_password_.c_str()); + + if (SSL_CTX_use_certificate_chain_file(ctx_, certfile_.c_str()) != 1) + return errors::Internal("Couldn't load cetificate chain (file '", certfile_, + "')"); + + string 
private_key_file = keyfile_.empty() ? certfile_ : keyfile_; + if (SSL_CTX_use_PrivateKey_file(ctx_, private_key_file.c_str(), + SSL_FILETYPE_PEM) != 1) + return errors::Internal("Couldn't load private key (file '", + private_key_file, "')"); + + return Status::OK(); +} + +Status SslWrapper::Connect() { + if (ctx_ == NULL) { + TF_RETURN_IF_ERROR(InitSslContext()); + } + + ssl_ = SSL_new(ctx_); + if (ssl_ == NULL) + return errors::Internal("Failed to establish SSL connection"); + + TF_RETURN_IF_ERROR(client_->Connect()); + + SSL_set_fd(ssl_, client_->GetSocketDescriptor()); + if (SSL_connect(ssl_) != 1) + return errors::Internal("Failed to establish SSL connection"); + + LOG(INFO) << "SSL connection established"; + + return Status::OK(); +} + +Status SslWrapper::Disconnect() { + SSL_free(ssl_); + ssl_ = nullptr; + + LOG(INFO) << "SSL connection closed"; + + return client_->Disconnect(); +} + +bool SslWrapper::IsConnected() { return client_->IsConnected(); } + +int SslWrapper::GetSocketDescriptor() { return client_->GetSocketDescriptor(); } + +Status SslWrapper::ReadData(uint8_t *buf, const int32_t length) { + int received = 0; + + while (received < length) { + int res = SSL_read(ssl_, buf, length - received); + + if (res < 0) + return errors::Internal("Error occurred while reading from SSL socket: ", + res); + + if (res == 0) return errors::Internal("Server closed SSL connection"); + + received += res; + buf += res; + } + + return Status::OK(); +} + +Status SslWrapper::WriteData(const uint8_t *buf, const int32_t length) { + int sent = 0; + + while (sent < length) { + int res = SSL_write(ssl_, buf, length - sent); + + if (res < 0) + return errors::Internal("Error occurred while writing into socket: ", + res); + + sent += res; + buf += res; + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h new file mode 100644 index 
// Client implementation that adds SSL on top of another Client: the wrapped
// client provides the socket, and all reads and writes go through OpenSSL.
class SslWrapper : public Client {
 public:
  // client:        Underlying client whose socket descriptor is used for SSL.
  // certfile:      Certificate chain file.
  // keyfile:       Private key file; when empty the key is read from
  //                `certfile`.
  // cert_password: Password handed to OpenSSL for an encrypted private key.
  // big_endian:    Forwarded to the Client base class.
  SslWrapper(std::shared_ptr client, string certfile, string keyfile,
             string cert_password, bool big_endian);
  // Disconnects if still connected and releases the SSL context and session.
  ~SslWrapper();

  // Creates the SSL context on first use, connects the wrapped client and
  // performs the SSL handshake over its socket.
  Status Connect() override;
  // Frees the SSL session and disconnects the wrapped client.
  Status Disconnect() override;
  // Reports the connection state of the wrapped client.
  bool IsConnected() override;
  // Returns the socket descriptor of the wrapped client.
  int GetSocketDescriptor() override;
  // Reads exactly `length` bytes into `buf` through the SSL session.
  Status ReadData(uint8_t* buf, const int32_t length) override;
  // Writes exactly `length` bytes from `buf` through the SSL session.
  Status WriteData(const uint8_t* buf, const int32_t length) override;

 private:
  // Creates `ctx_` and loads the certificate chain and private key into it.
  Status InitSslContext();

  std::shared_ptr client_;
  string certfile_;
  string keyfile_;
  string cert_password_;
  // OpenSSL context; nullptr until InitSslContext() has been called.
  SSL_CTX* ctx_;
  // OpenSSL session of the current connection; nullptr when there is none.
  SSL* ssl_;
};
0000000000000000000000000000000000000000..3d6fbe00e6296941b4ce77d1238a79099bb9a5aa --- /dev/null +++ b/tensorflow/contrib/ignite/ops/dataset_ops.cc @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("IgniteDataset") + .Input("cache_name: string") + .Input("host: string") + .Input("port: int32") + .Input("local: bool") + .Input("part: int32") + .Input("page_size: int32") + .Input("schema: int32") + .Input("permutation: int32") + .Output("handle: variant") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +IgniteDataset that allows to get data from Apache Ignite. + +Apache Ignite is a memory-centric distributed database, caching, and processing +platform for transactional, analytical, and streaming workloads, delivering +in-memory speeds at petabyte scale. This contrib package contains an +integration between Apache Ignite and TensorFlow. The integration is based on +tf.data from TensorFlow side and Binary Client Protocol from Apache Ignite side. +It allows to use Apache Ignite as a datasource for neural network training, +inference and all other computations supported by TensorFlow. 
Ignite Dataset +is based on Apache Ignite Binary Client Protocol. + +cache_name: Ignite Cache Name. +host: Ignite Thin Client Host. +port: Ignite Thin Client Port. +local: Local flag that defines that data should be fetched from local host only. +part: Partition data should be fetched from. +page_size: Page size for Ignite Thin Client. +schema: Internal structure that defines schema of cache objects. +permutation: Internal structure that defines permutation of cache objects. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..288d4853207176b215cd8a0cdcbfb2de5791ecb8 --- /dev/null +++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py @@ -0,0 +1,772 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class Readable(object):
  """Base class for little-endian binary readers.

  Subclasses supply `read_data`; the typed `read_*` helpers and `skip` are
  built on top of it.
  """

  @abc.abstractmethod
  def __init__(self):
    pass

  def read_byte(self):
    """Reads and returns a signed byte."""
    return self._read("b", 1)

  def read_short(self):
    """Reads and returns a short (2 bytes, little-endian)."""
    return self._read("h", 2)

  def read_int(self):
    """Reads and returns an int (4 bytes, little-endian)."""
    return self._read("i", 4)

  def read_long(self):
    """Reads and returns a long (8 bytes, little-endian)."""
    return self._read("q", 8)

  def skip(self, length):
    """Consumes `length` bytes and discards them."""
    self.read_data(length)

  @abc.abstractmethod
  def read_data(self, length):
    """Reads `length` bytes and returns them as a buffer."""
    return None

  def _read(self, data_type, length):
    """Reads `length` bytes and decodes them as one little-endian value.

    Args:
      data_type: `struct` format character of the value.
      length: Number of bytes occupied by the value.

    Returns:
      The decoded value.
    """
    raw = self.read_data(length)
    unpacked = struct.unpack("<" + data_type, raw)
    return unpacked[0]
def read_data(self, length):
  """Returns the next `length` bytes of the buffer and advances the cursor.

  The cursor always moves by `length`, even when fewer bytes are left.
  """
  remaining = self.buffer[self.ptr:]
  chunk = remaining[:length]
  self.ptr = self.ptr + length
  return chunk
def read_data(self, length):
  """Reads exactly `length` bytes from the socket.

  `recv` may return fewer bytes than requested, so it is called repeatedly
  until the requested amount has been collected.

  Args:
    length: Number of bytes to read.

  Returns:
    A byte string of `length` bytes (empty when `length` is not positive).

  Raises:
    RuntimeError: If the connection is closed before `length` bytes arrive.
  """
  chunks = []
  rem = length
  while rem > 0:
    buf = self.sock.recv(rem)
    # An empty result means the peer closed the connection; without this
    # check the loop would never terminate.
    if not buf:
      raise RuntimeError(
          "Connection closed while reading data [expected=%d, received=%d]" %
          (length, length - rem))
    chunks.append(buf)
    rem -= len(buf)
  return b"".join(chunks)
class BinaryField(object):
  """Plain holder for a binary field: its name, type id and field id."""

  def __init__(self, field_name, type_id, field_id):
    """Stores the three attributes unchanged.

    Args:
      field_name: Name of the field.
      type_id: Type id of the field.
      field_id: Id of the field.
    """
    (self.field_name, self.type_id, self.field_id) = (field_name, type_id,
                                                      field_id)
def to_flat_rec(self, flat):
  """Appends the type ids of all leaves below this node to `flat`.

  Leaves are visited in pre-order, following the order of `fields`.

  Args:
    flat: List that receives the type ids; it is modified in place.

  Returns:
    The same `flat` list.
  """
  children = self.fields
  if children is None:
    # A node without a field list is a leaf.
    flat.append(self.type_id)
    return flat
  for child in children:
    child.to_flat_rec(flat)
  return flat
def traversal_rec(self, d, i):
  """Numbers the leaves below this node in pre-order.

  Args:
    d: Dictionary that receives a `leaf -> index` entry for every leaf.
    i: Index to assign to the first leaf that is encountered.

  Returns:
    The index that follows the last one assigned.
  """
  if self.fields is None:
    d[self] = i
    return i + 1
  next_index = i
  for child in self.fields:
    next_index = child.traversal_rec(d, next_index)
  return next_index
def handshake(self):
  """Makes a handshake after connect and before any other calls.

  Sends the handshake request (protocol version 1.1.0, thin client code 2)
  together with the optional username and password and checks the response.

  Raises:
    RuntimeError: If the server does not report a successful handshake.
  """
  # Strings are sent UTF-8 encoded, so their lengths have to be counted in
  # bytes; counting characters breaks the message for non-ASCII credentials.
  if self.username is None:
    username_len = None
  else:
    username_len = len(self.username.encode("UTF-8"))

  if self.password is None:
    password_len = None
  else:
    password_len = len(self.password.encode("UTF-8"))

  # Operation code (1) + version (3 * 2) + client code (1).
  msg_len = 8

  # A missing value takes 1 byte (NULL marker); a string takes a type byte,
  # a 4-byte length and the encoded payload.
  for value_len in (username_len, password_len):
    if value_len is None:
      msg_len += 1
    else:
      msg_len += 5 + value_len

  self.write_int(msg_len)  # Message length
  self.write_byte(1)  # Handshake operation
  self.write_short(1)  # Version (1.1.0)
  self.write_short(1)
  self.write_short(0)
  self.write_byte(2)  # Thin client

  if self.username is None:  # Username
    self.write_byte(101)
  else:
    self.write_byte(9)
    self.write_int(username_len)
    self.write_string(self.username)

  if self.password is None:  # Password
    self.write_byte(101)
  else:
    self.write_byte(9)
    self.write_int(password_len)
    self.write_string(self.password)

  self.read_int()  # Result length
  res = self.read_byte()

  if res != 1:
    serv_ver_major = self.read_short()
    serv_ver_minor = self.read_short()
    serv_ver_patch = self.read_short()
    err_msg = self._parse_string()
    if err_msg is None:
      raise RuntimeError(
          "Handshake Error [result=%d, version=%d.%d.%d]" %
          (res, serv_ver_major, serv_ver_minor, serv_ver_patch))
    else:
      raise RuntimeError(
          "Handshake Error [result=%d, version=%d.%d.%d, message='%s']" %
          (res, serv_ver_major, serv_ver_minor, serv_ver_patch, err_msg))
Local flag + + result_length = self.read_int() + self.read_long() # Request id + status = self.read_int() + + if status != 0: + err_msg = self._parse_string() + if err_msg is None: + raise RuntimeError("Scan Query Error [status=%s]" % status) + else: + raise RuntimeError( + "Scan Query Error [status=%s, message='%s']" % (status, err_msg)) + + self.read_long() # Cursor id + row_count = self.read_int() + + if row_count == 0: + raise RuntimeError("Scan Query returned empty result, so it's " + "impossible to derive the cache type") + + payload = DataBuffer(self.read_data(result_length - 25)) + + self.read_byte() # Next page + + res = TypeTreeNode("root", 0, [ + self._collect_types("key", payload), + self._collect_types("val", payload) + ], [0, 1]) + + return res + + def _java_hash_code(self, s): + """Computes hash code of the specified string using Java code.""" + h = 0 + for c in s: + h = (31 * h + ord(c)) & 0xFFFFFFFF + return ((h + 0x80000000) & 0xFFFFFFFF) - 0x80000000 + + def _collect_types(self, field_name, data): + """Extracts type information from the specified object.""" + type_id = data.read_byte() + + # Byte scalar. + if type_id == 1: + data.skip(1) + return TypeTreeNode(field_name, type_id) + + # Short scalar. + if type_id == 2: + data.skip(2) + return TypeTreeNode(field_name, type_id) + + # Integer scalar. + if type_id == 3: + data.skip(4) + return TypeTreeNode(field_name, type_id) + + # Long scalar. + if type_id == 4: + data.skip(8) + return TypeTreeNode(field_name, type_id) + + # Float scalar. + if type_id == 5: + data.skip(4) + return TypeTreeNode(field_name, type_id) + + # Double scalar. + if type_id == 6: + data.skip(8) + return TypeTreeNode(field_name, type_id) + + # Char scalar. + if type_id == 7: + data.skip(2) + return TypeTreeNode(field_name, type_id) + + # Bool scalar. + if type_id == 8: + data.skip(1) + return TypeTreeNode(field_name, type_id) + + # String scalar. 
+ if type_id == 9: + length = data.read_int() + data.skip(length) + return TypeTreeNode(field_name, type_id) + + # UUID scalar. + if type_id == 10: + data.skip(16) + return TypeTreeNode(field_name, type_id) + + # Date scalar. + if type_id == 11: + data.skip(8) + return TypeTreeNode(field_name, type_id) + + # Byte array. + if type_id == 12: + length = data.read_int() + data.skip(length) + return TypeTreeNode(field_name, type_id) + + # Short array. + if type_id == 13: + length = data.read_int() + data.skip(length * 2) + return TypeTreeNode(field_name, type_id) + + # Integer array. + if type_id == 14: + length = data.read_int() + data.skip(length * 4) + return TypeTreeNode(field_name, type_id) + + # Long array. + if type_id == 15: + length = data.read_int() + data.skip(length * 8) + return TypeTreeNode(field_name, type_id) + + # Float array. + if type_id == 16: + length = data.read_int() + data.skip(length * 4) + return TypeTreeNode(field_name, type_id) + + # Double array. + if type_id == 17: + length = data.read_int() + data.skip(length * 8) + return TypeTreeNode(field_name, type_id) + + # Char array. + if type_id == 18: + length = data.read_int() + data.skip(length * 2) + return TypeTreeNode(field_name, type_id) + + # Bool array. + if type_id == 19: + length = data.read_int() + data.skip(length) + return TypeTreeNode(field_name, type_id) + + # String array. + if type_id == 20: + length = data.read_int() + for _ in range(length): + header = data.read_byte() + if header == 9: + str_length = data.read_int() + data.skip(str_length) + elif header == 101: + pass + else: + raise RuntimeError( + "Unknown binary type when expected string [type_id=%d]" % header) + return TypeTreeNode(field_name, type_id) + + # UUID array. + if type_id == 21: + length = data.read_int() + data.skip(length * 16) # TODO(dmitrievanthony): support NULL values. + return TypeTreeNode(field_name, type_id) + + # Date array. 
+ if type_id == 22: + length = data.read_int() + data.skip(length * 8) + return TypeTreeNode(field_name, type_id) + + # Wrapped Binary Object. + if type_id == 27: + length = data.read_int() + inner_data = data.read_data(length) + data.read_int() # Offset + return self._collect_types(field_name, DataBuffer(inner_data)) + + # Complex Object. + if type_id == 103: + data.read_byte() # Object version + data.read_short() # Object flags + obj_type_id = data.read_int() + data.read_int() # Object hash code + obj_length = data.read_int() + data.read_int() # Object schema id + obj_schema_offset = data.read_int() + + obj_type = self._get_type(obj_type_id) + children = [] + + for obj_field in obj_type.fields: + child = self._collect_types(obj_field.field_name, data) + children.append(child) + + children_sorted = sorted(children, key=lambda child: child.name) + permutation = [children_sorted.index(child) for child in children] + children = children_sorted + + data.skip(obj_length - obj_schema_offset) + + return TypeTreeNode(field_name, type_id, children, permutation) + + raise RuntimeError("Unknown binary type [type_id=%d]" % type_id) + + def _get_type(self, type_id): + """Queries Apache Ignite information about type by type id.""" + self.write_int(14) # Message length + self.write_short(3002) # Operation code + self.write_long(0) # Request ID + self.write_int(type_id) # Type ID + + self.read_int() # Result length + self.read_long() # Request id + status = self.read_int() + + if status != 0: + err_msg = self._parse_string() + if err_msg is None: + raise RuntimeError("Get Binary Type Error [status=%d, message='%s']" % + (status, err_msg)) + else: + raise RuntimeError("Get Binary Type Error [status=%d]" % status) + + binary_type_exists = self.read_byte() + + if binary_type_exists == 0: + raise RuntimeError("Binary type not found [type_id=%d] " % type_id) + + binary_type_id = self.read_int() + binary_type_name = self._parse_string() + self._parse_string() # Affinity field name + + 
fields = [] + for _ in range(self.read_int()): + field_name = self._parse_string() + field_type_id = self.read_int() + field_id = self.read_int() + + field = BinaryField(field_name, field_type_id, field_id) + fields.append(field) + + is_enum = self.read_byte() + if is_enum == 1: + raise RuntimeError("Enum fields are not supported yet") + + schema_cnt = self.read_int() + for _ in range(schema_cnt): + self.read_int() # Schema id + field_cnt = self.read_int() + self.skip(field_cnt * 4) + + return BinaryType(binary_type_id, binary_type_name, fields) + + def _parse_string(self): + """Parses string.""" + header = self.read_byte() + if header == 9: + length = self.read_int() + return self.read_data(length).decode("utf-8") + if header == 101: + return None + raise RuntimeError( + "Unknown binary type when expected string [type_id=%d]" % header) + + +class IgniteDataset(dataset_ops.DatasetSource): + """Apache Ignite is a memory-centric distributed database, caching, and + + processing platform for transactional, analytical, and streaming workloads, + delivering in-memory speeds at petabyte scale. This contrib package + contains an integration between Apache Ignite and TensorFlow. The + integration is based on tf.data from TensorFlow side and Binary Client + Protocol from Apache Ignite side. It allows to use Apache Ignite as a + datasource for neural network training, inference and all other + computations supported by TensorFlow. Ignite Dataset is based on Apache + Ignite Binary Client Protocol. + """ + + def __init__(self, + cache_name, + host="localhost", + port=10800, + local=False, + part=-1, + page_size=100, + username=None, + password=None, + certfile=None, + keyfile=None, + cert_password=None): + """Create a IgniteDataset. + + Args: + cache_name: Cache name to be used as datasource. + host: Apache Ignite Thin Client host to be connected. + port: Apache Ignite Thin Client port to be connected. + local: Local flag that defines to query only local data. 
+ part: Number of partitions to be queried. + page_size: Apache Ignite Thin Client page size. + username: Apache Ignite Thin Client authentication username. + password: Apache Ignite Thin Client authentication password. + certfile: File in PEM format containing the certificate as well as any + number of CA certificates needed to establish the certificate's + authenticity. + keyfile: File containing the private key (otherwise the private key will + be taken from certfile as well). + cert_password: Password to be used if the private key is encrypted and a + password is necessary. + """ + super(IgniteDataset, self).__init__() + + with IgniteClient(host, port, username, password, certfile, keyfile, + cert_password) as client: + client.handshake() + self.cache_type = client.get_cache_type(cache_name) + + self.cache_name = ops.convert_to_tensor( + cache_name, dtype=dtypes.string, name="cache_name") + self.host = ops.convert_to_tensor(host, dtype=dtypes.string, name="host") + self.port = ops.convert_to_tensor(port, dtype=dtypes.int32, name="port") + self.local = ops.convert_to_tensor(local, dtype=dtypes.bool, name="local") + self.part = ops.convert_to_tensor(part, dtype=dtypes.int32, name="part") + self.page_size = ops.convert_to_tensor( + page_size, dtype=dtypes.int32, name="page_size") + self.schema = ops.convert_to_tensor( + self.cache_type.to_flat(), dtype=dtypes.int32, name="schema") + self.permutation = ops.convert_to_tensor( + self.cache_type.to_permutation(), + dtype=dtypes.int32, + name="permutation") + + def _as_variant_tensor(self): + return gen_dataset_ops.ignite_dataset(self.cache_name, self.host, self.port, + self.local, self.part, self.page_size, + self.schema, self.permutation) + + @property + def output_classes(self): + return self.cache_type.to_output_classes() + + @property + def output_shapes(self): + return self.cache_type.to_output_shapes() + + @property + def output_types(self): + return self.cache_type.to_output_types() diff --git 
a/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..c9af7386cf0a26ed1a950130aa36caa7fb831fd0 --- /dev/null +++ b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py @@ -0,0 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python helper for loading Ignite ops and kernels.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_dataset_ops.so")) diff --git a/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh b/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh new file mode 100755 index 0000000000000000000000000000000000000000..f4607ce8adab38c27d040ad1118858d17b924a6a --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +nohup apache-ignite-fabric/bin/ignite.sh /data/config/ignite-config-plain.xml & +sleep 5 # Wait Apache Ignite to be started + +./apache-ignite-fabric/bin/sqlline.sh \ +-u "jdbc:ignite:thin://127.0.0.1/" \ +--run=/data/sql/init.sql + +tail -f nohup.out diff --git a/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml b/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml new file mode 100644 index 0000000000000000000000000000000000000000..d900174a8abb3987c380e4a1a193ed81295fb88c --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml @@ -0,0 +1,39 @@ + + + + + + + + + + + + + 127.0.0.1 + + + + + + + + + diff --git a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ef29b5f14a4b2fea2400ec4d56a7ad2cf44cf2cb --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py @@ -0,0 +1,82 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. 
You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for IgniteDataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.ignite import IgniteDataset +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class IgniteDatasetTest(test.TestCase): + """The Apache Ignite servers have to setup before the test and tear down + + after the test manually. The docker engine has to be installed. + + To setup Apache Ignite servers: + $ bash start_ignite.sh + + To tear down Apache Ignite servers: + $ bash stop_ignite.sh + """ + + def test_ignite_dataset_with_plain_client(self): + """Test Ignite Dataset with plain client. + + """ + self._clear_env() + ds = IgniteDataset(cache_name="SQL_PUBLIC_TEST_CACHE", port=42300) + self._check_dataset(ds) + + def _clear_env(self): + """Clears environment variables used by Ignite Dataset. 
+ + """ + if "IGNITE_DATASET_USERNAME" in os.environ: + del os.environ["IGNITE_DATASET_USERNAME"] + if "IGNITE_DATASET_PASSWORD" in os.environ: + del os.environ["IGNITE_DATASET_PASSWORD"] + if "IGNITE_DATASET_CERTFILE" in os.environ: + del os.environ["IGNITE_DATASET_CERTFILE"] + if "IGNITE_DATASET_CERT_PASSWORD" in os.environ: + del os.environ["IGNITE_DATASET_CERT_PASSWORD"] + + def _check_dataset(self, dataset): + """Checks that dataset provides correct data.""" + self.assertEqual(dtypes.int64, dataset.output_types["key"]) + self.assertEqual(dtypes.string, dataset.output_types["val"]["NAME"]) + self.assertEqual(dtypes.int64, dataset.output_types["val"]["VAL"]) + + it = dataset.make_one_shot_iterator() + ne = it.get_next() + + with session.Session() as sess: + rows = [sess.run(ne), sess.run(ne), sess.run(ne)] + with self.assertRaises(errors.OutOfRangeError): + sess.run(ne) + + self.assertEqual({"key": 1, "val": {"NAME": b"TEST1", "VAL": 42}}, rows[0]) + self.assertEqual({"key": 2, "val": {"NAME": b"TEST2", "VAL": 43}}, rows[1]) + self.assertEqual({"key": 3, "val": {"NAME": b"TEST3", "VAL": 44}}, rows[2]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/ignite/python/tests/sql/init.sql b/tensorflow/contrib/ignite/python/tests/sql/init.sql new file mode 100644 index 0000000000000000000000000000000000000000..5a192aef17e22544e853cb78b4eb235beded42fe --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/sql/init.sql @@ -0,0 +1,20 @@ +-- Copyright 2018 The TensorFlow Authors. All Rights Reserved. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. 
+-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- ============================================================================== + +CREATE TABLE TEST_CACHE (ID LONG PRIMARY KEY, NAME VARCHAR, VAL LONG); + +INSERT INTO TEST_CACHE VALUES (1, 'TEST1', 42); +INSERT INTO TEST_CACHE VALUES (2, 'TEST2', 43); +INSERT INTO TEST_CACHE VALUES (3, 'TEST3', 44); diff --git a/tensorflow/contrib/ignite/python/tests/start_ignite.sh b/tensorflow/contrib/ignite/python/tests/start_ignite.sh new file mode 100755 index 0000000000000000000000000000000000000000..a67bd44f2fb0d654ba07f022a5070c68df8e2ede --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/start_ignite.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +IGNITE_VERSION=2.6.0 +SCRIPT_PATH="$( cd "$(dirname "$0")" ; pwd -P )" + +# Start Apache Ignite with plain client listener. 
+docker run -itd --name ignite-plain -p 42300:10800 \ +-v ${SCRIPT_PATH}:/data apacheignite/ignite:${IGNITE_VERSION} /data/bin/start-plain.sh diff --git a/tensorflow/contrib/ignite/python/tests/stop_ignite.sh b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh new file mode 100755 index 0000000000000000000000000000000000000000..8f03dbd1ede61f548d3de9d9738f97667e75df3c --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +docker rm -f ignite-plain +docker rm -f ignite-ssl +docker rm -f ignite-ssl-auth diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index da450480b30b548484e69c61c85667d6dd390417..c9d917fe20dbcef1aa4a8dae3db935bcef73b281 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -49,6 +49,7 @@ tf_kernel_library( "kernels/image_ops.h", ], deps = [ + ":image_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//third_party/eigen3", @@ -74,7 +75,6 @@ tf_custom_op_py_library( dso = [":python/ops/_image_ops.so"], kernels = [ ":image_ops_kernels", - ":image_ops_op_lib", ], srcs_version = "PY2AND3", deps = [ @@ -128,6 +128,26 @@ tf_custom_op_library( ], ) +tf_kernel_library( + name = "distort_image_ops_kernels", + srcs = [ + "kernels/adjust_hsv_in_yiq_op.cc", + "kernels/adjust_hsv_in_yiq_op.h", + ], + gpu_srcs = [ + "kernels/adjust_hsv_in_yiq_op_gpu.cu.cc", + "kernels/adjust_hsv_in_yiq_op.h", + ], + deps = [ + ":distort_image_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/kernels:gpu_util_hdrs", + "//third_party/eigen3", + ], + alwayslink = 1, +) + tf_cc_test( name = "adjust_hsv_in_yiq_op_test", size = "small", @@ -155,13 +175,16 @@ tf_gen_op_wrapper_py( deps = [":distort_image_ops_op_lib"], ) -py_library( +tf_custom_op_py_library( name = "distort_image_py", srcs = [ "__init__.py", "python/ops/distort_image_ops.py", ], - data = [":python/ops/_distort_image_ops.so"], + dso = [":python/ops/_distort_image_ops.so"], + kernels = [ + ":distort_image_ops_kernels", + ], srcs_version = "PY2AND3", deps = [ ":distort_image_ops", @@ -338,25 +361,36 @@ tf_gen_op_libs( op_lib_names = ["single_image_random_dot_stereograms_ops"], ) +tf_kernel_library( + name = "single_image_random_dot_stereograms_kernels", + srcs = [ + "kernels/single_image_random_dot_stereograms_ops.cc", + ], + deps = [ + 
":single_image_random_dot_stereograms_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//third_party/eigen3", + ], +) + tf_gen_op_wrapper_py( name = "single_image_random_dot_stereograms_ops", deps = [":single_image_random_dot_stereograms_ops_op_lib"], ) -cc_library( +alias( name = "image_ops_cc", - srcs = ["ops/image_ops.cc"], - deps = [ - ":image_ops_kernels", - "//tensorflow/core:framework", - ], - alwayslink = 1, + actual = ":image_ops_op_lib", ) -py_library( +tf_custom_op_py_library( name = "single_image_random_dot_stereograms_py", srcs = glob(["python/ops/single*.py"]) + ["__init__.py"], - data = [":python/ops/_single_image_random_dot_stereograms.so"], + dso = [":python/ops/_single_image_random_dot_stereograms.so"], + kernels = [ + ":single_image_random_dot_stereograms_kernels", + ], srcs_version = "PY2AND3", deps = [ ":image_py", diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py index f230d93da4a9c01e8dee47aa258d9c28499469f1..91b8e8d0f93c5ac6af0e1863ab309eb97525f6a0 100755 --- a/tensorflow/contrib/image/__init__.py +++ b/tensorflow/contrib/image/__init__.py @@ -58,6 +58,7 @@ from tensorflow.contrib.image.python.ops.distort_image_ops import adjust_hsv_in_ from tensorflow.contrib.image.python.ops.distort_image_ops import random_hsv_in_yiq from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_transforms +from tensorflow.contrib.image.python.ops.image_ops import bipartite_match from tensorflow.contrib.image.python.ops.image_ops import compose_transforms from tensorflow.contrib.image.python.ops.image_ops import connected_components from tensorflow.contrib.image.python.ops.image_ops import flat_transforms_to_matrices diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc index 370a8caf6a71cc09629a5e75fd9151ae3f0f3b6d..788bf04b28aaad5d532258e0946fd03111384c69 100644 --- a/tensorflow/contrib/image/kernels/image_ops.cc +++ 
b/tensorflow/contrib/image/kernels/image_ops.cc @@ -156,6 +156,7 @@ namespace functor { TF_CALL_uint8(DECLARE_FUNCTOR); TF_CALL_int32(DECLARE_FUNCTOR); TF_CALL_int64(DECLARE_FUNCTOR); +TF_CALL_half(DECLARE_FUNCTOR); TF_CALL_float(DECLARE_FUNCTOR); TF_CALL_double(DECLARE_FUNCTOR); @@ -175,6 +176,7 @@ TF_CALL_double(DECLARE_FUNCTOR); TF_CALL_uint8(REGISTER); TF_CALL_int32(REGISTER); TF_CALL_int64(REGISTER); +TF_CALL_half(REGISTER); TF_CALL_float(REGISTER); TF_CALL_double(REGISTER); diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h index 6b63eed1303accc330293b3a44cdb9def7881666..7fac774d07fa8e07a0730ad018ba70e2c73a9cc5 100644 --- a/tensorflow/contrib/image/kernels/image_ops.h +++ b/tensorflow/contrib/image/kernels/image_ops.h @@ -71,14 +71,7 @@ class ProjectiveGenerator { (transform[3] * output_x + transform[4] * output_y + transform[5]) / projection; - // TODO(ringwalt): Add a fill value input. -#if (defined __CUDA_ARCH__) && (CUDART_VERSION < 8000) - // On CUDA versions previous to 8.0, only __shared__ variables - // could be declared as static in the device code. const T fill_value = T(0); -#else - static const T fill_value = T(0); -#endif switch (interpolation_) { case INTERPOLATION_NEAREST: // Switch the order of x and y again for indexing into the image. 
diff --git a/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc b/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc index 8743a5ff724a5000ed0376045340f9ceaaccbfd2..36b9a236a6ea48e3b27dac956c93aecee321e2b7 100644 --- a/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc +++ b/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc @@ -32,6 +32,7 @@ typedef Eigen::GpuDevice GPUDevice; template class FillProjectiveTransform; template class FillProjectiveTransform; template class FillProjectiveTransform; +template class FillProjectiveTransform; template class FillProjectiveTransform; template class FillProjectiveTransform; diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py index 376c0751eebb4906920ed338647630798d509113..4997c31a7fc7f4243d03b22fc9c01fb13a2a25a4 100644 --- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py @@ -272,6 +272,15 @@ class ImageOpsTest(test_util.TensorFlowTestCase): with self.cached_session(): self.assertAllEqual([[[[1], [0]], [[0], [1]]]], result.eval()) + def test_transform_data_types(self): + for dtype in _DTYPES: + image = constant_op.constant([[1, 2], [3, 4]], dtype=dtype) + value = image_ops.transform(image, [1] * 8) + with self.test_session(use_gpu=True): + self.assertAllEqual( + value.eval(), + np.array([[4, 4], [4, 4]]).astype(dtype.as_numpy_dtype())) + class BipartiteMatchTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh index adf027b8e714124cde2b4618546e20c6b7162e1f..69553c3bd15c9359a6ab879bc4e104bd5c30beac 100644 --- a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh +++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh @@ -22,8 +22,12 @@ if [ "$#" -ne 2 ]; then exit 1 fi +action=$1 container=$2 -if [ "$1" == "start" ]; 
then +if [ "$action" == "start" ]; then + echo pull spotify/kafka + docker pull spotify/kafka + echo pull spotify/kafka successfully docker run -d --rm --net=host --name=$container spotify/kafka echo Wait 5 secs until kafka is up and running sleep 5 @@ -33,12 +37,10 @@ if [ "$1" == "start" ]; then docker exec $container bash -c 'echo -e "D0\nD1\nD2\nD3\nD4\nD5\nD6\nD7\nD8\nD9" > /test' echo Produce test message docker exec $container bash -c '/opt/kafka_2.11-0.10.1.0/bin/kafka-console-producer.sh --topic test --broker-list 127.0.0.1:9092 < /test' - echo Container $container started successfully -elif [ "$1" == "stop" ]; then +elif [ "$action" == "stop" ]; then docker rm -f $container - - echo Container $container stopped successfully + echo Container $container removed successfully else echo "Usage: $0 start|stop " >&2 exit 1 diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 3b7ae72e9c460ee7a38f72b03e1c1ad48e335f57..8ead6336a08db4dd52edf0d3372db5a50f860e2b 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -630,7 +630,7 @@ class ConvolutionTest(test.TestCase): expected_size = [None, num_filters, None, None] expected_size_dynamic = [5, num_filters, 7, 9] - with self.test_session(use_gpu=True): + with self.session(use_gpu=True): images = array_ops.placeholder(np.float32, [None, input_size[1], None, None]) output = layers_lib.convolution2d( @@ -721,7 +721,7 @@ class Convolution2dTransposeTests(test.TestCase): def testOutputSizeWithStrideOneSamePaddingNCHW(self): # `NCHW` data format is only supported for `GPU` device. 
if test.is_gpu_available(cuda_only=True): - with self.test_session(use_gpu=True) as sess: + with self.session(use_gpu=True) as sess: num_filters = 32 input_size = [5, 3, 10, 12] expected_size = [5, num_filters, 10, 12] @@ -740,7 +740,7 @@ class Convolution2dTransposeTests(test.TestCase): def testOutputSizeWithStrideOneValidPaddingNCHW(self): if test.is_gpu_available(cuda_only=True): - with self.test_session(use_gpu=True) as sess: + with self.session(use_gpu=True) as sess: num_filters = 32 input_size = [5, 3, 10, 12] expected_size = [5, num_filters, 12, 14] @@ -759,7 +759,7 @@ class Convolution2dTransposeTests(test.TestCase): def testOutputSizeWithStrideTwoValidPaddingNCHW(self): if test.is_gpu_available(cuda_only=True): - with self.test_session(use_gpu=True) as sess: + with self.session(use_gpu=True) as sess: num_filters = 32 input_size = [5, 3, 9, 11] expected_size = [5, num_filters, 19, 23] @@ -779,7 +779,7 @@ class Convolution2dTransposeTests(test.TestCase): def testOutputSizeWith1x1StrideTwoSamePaddingNCHW(self): if test.is_gpu_available(cuda_only=True): - with self.test_session(use_gpu=True) as sess: + with self.session(use_gpu=True) as sess: num_filters = 1 input_size = [1, 1, 1, 1] expected_size = [1, num_filters, 2, 2] @@ -799,7 +799,7 @@ class Convolution2dTransposeTests(test.TestCase): def testOutputSizeWith1x1StrideTwoValidPaddingNCHW(self): if test.is_gpu_available(cuda_only=True): - with self.test_session(use_gpu=True) as sess: + with self.session(use_gpu=True) as sess: num_filters = 1 input_size = [1, 1, 1, 1] expected_size = [1, num_filters, 2, 2] @@ -817,7 +817,7 @@ class Convolution2dTransposeTests(test.TestCase): def testOutputSizeWith2x2StrideTwoSamePaddingNCHW(self): if test.is_gpu_available(cuda_only=True): - with self.test_session(use_gpu=True) as sess: + with self.session(use_gpu=True) as sess: num_filters = 1 input_size = [1, 1, 2, 2] expected_size = [1, num_filters, 4, 4] @@ -835,7 +835,7 @@ class Convolution2dTransposeTests(test.TestCase): 
def testOutputSizeWith2x2StrideTwoValidPaddingNCHW(self): if test.is_gpu_available(cuda_only=True): - with self.test_session(use_gpu=True) as sess: + with self.session(use_gpu=True) as sess: num_filters = 1 input_size = [1, 1, 2, 2] expected_size = [1, num_filters, 4, 4] @@ -853,7 +853,7 @@ class Convolution2dTransposeTests(test.TestCase): def testOutputSizeWithStride2x1NCHW(self): if test.is_gpu_available(cuda_only=True): - with self.test_session(use_gpu=True) as sess: + with self.session(use_gpu=True) as sess: num_filters = 1 input_size = [1, 1, 3, 2] expected_size = [1, num_filters, 6, 5] @@ -871,7 +871,7 @@ class Convolution2dTransposeTests(test.TestCase): def testOutputSizeWithStride2x4NCHW(self): if test.is_gpu_available(cuda_only=True): - with self.test_session(use_gpu=True) as sess: + with self.session(use_gpu=True) as sess: num_filters = 1 input_size = [1, 1, 3, 2] expected_size = [1, num_filters, 6, 8] @@ -889,7 +889,7 @@ class Convolution2dTransposeTests(test.TestCase): def testOutputSizeWithStride2x5NCHW(self): if test.is_gpu_available(cuda_only=True): - with self.test_session(use_gpu=True) as sess: + with self.session(use_gpu=True) as sess: num_filters = 1 input_size = [1, 1, 3, 2] expected_size = [1, num_filters, 6, 10] @@ -2056,7 +2056,7 @@ class BatchNormTest(test.TestCase): channels = 3 np.random.seed(1) use_gpu = fused - with self.test_session(use_gpu=use_gpu) as sess: + with self.session(use_gpu=use_gpu) as sess: if data_format == 'NHWC': image_shape = (batch_size, height, width, channels) axis = (0, 1, 2) @@ -2140,7 +2140,7 @@ class BatchNormTest(test.TestCase): channels = 3 np.random.seed(1) use_gpu = fused - with self.test_session(use_gpu=use_gpu) as sess: + with self.session(use_gpu=use_gpu) as sess: if data_format == 'NHWC': image_shape = (batch_size, height, width, channels) axis = (0, 1, 2) @@ -2344,7 +2344,7 @@ class BatchNormTest(test.TestCase): np.random.seed(1) use_gpu = fused np.random.seed(1) - with self.test_session(use_gpu=use_gpu) 
as sess: + with self.session(use_gpu=use_gpu) as sess: if data_format == 'NHWC': image_shape = (batch_size, height, width, channels) axis = (0, 1, 2) @@ -2491,7 +2491,7 @@ class BatchNormTest(test.TestCase): channels = 3 np.random.seed(1) use_gpu = fused - with self.test_session(use_gpu=use_gpu) as sess: + with self.session(use_gpu=use_gpu) as sess: if data_format == 'NHWC': image_shape = (batch_size, height, width, channels) axis = (0, 1, 2) @@ -2576,7 +2576,7 @@ class BatchNormTest(test.TestCase): channels = 32 np.random.seed(1) use_gpu = fused - with self.test_session(use_gpu=use_gpu) as sess: + with self.session(use_gpu=use_gpu) as sess: if data_format == 'NHWC': image_shape = (batch_size, height, width, channels) axis = (0, 1, 2) @@ -2674,7 +2674,7 @@ class BatchNormTest(test.TestCase): def _runBatchNormalizationWithFormat(self, shape, data_format, is_training): channels = shape[-1] - with self.test_session(use_gpu=True) as sess: + with self.session(use_gpu=True) as sess: images = np.arange(np.product(shape), dtype=np.float32).reshape(shape) beta = init_ops.constant_initializer( np.arange(2, channels + 2, dtype=np.float32)) @@ -2776,7 +2776,7 @@ class BatchNormTest(test.TestCase): 'moving_variance': variance, }, data_format='NCHW') - with self.test_session(use_gpu=True) as sess: + with self.session(use_gpu=True) as sess: sess.run(variables_lib.global_variables_initializer()) return sess.run(output) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index c1de42782efb3497660affb3ef7162457977c150..3efceab3375d3a1801c87122c98920cc523a3aca 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -1433,13 +1433,12 @@ class Estimator(BaseEstimator): 'must specify no transforms.') untransformed_tags = graph_rewrite_specs[0].tags - # TODO(soergel): switch to main_op or otherwise update when 
dust settles builder.add_meta_graph_and_variables( session, untransformed_tags, signature_def_map=signature_def_map, assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), - legacy_init_op=init_op, + main_op=init_op, strip_default_attrs=strip_default_attrs) # pylint: disable=protected-access diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index b98adf862bf1514b43d237196cb2de531a909479..94ff1dd5b01b5a24d1deb7053553b9df48709c7c 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -22,6 +22,7 @@ import collections from six.moves import range from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable +from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -151,7 +152,8 @@ class SdcaModel(object): default_value=[0.0, 0.0, 0.0, 0.0], # SdcaFprint never returns 0 or 1 for the low64 bits, so this a safe # empty_key (that will never collide with actual payloads). - empty_key=[0, 0]) + empty_key=[0, 0], + deleted_key=[1, 1]) summary.scalar('approximate_duality_gap', self.approximate_duality_gap()) summary.scalar('examples_seen', self._hashtable.size()) @@ -202,7 +204,7 @@ class SdcaModel(object): with ops.colocate_with(v): # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 # is fixed. - slot_var = var_ops.Variable( + slot_var = var_ops.VariableV1( initial_value=array_ops.zeros_like(v.initialized_value(), dtypes.float32), name=v.op.name + '_unshrinked/SDCAOptimizer') @@ -214,7 +216,7 @@ class SdcaModel(object): # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is # fixed. 
self._slots['unshrinked_' + name].append( - var_ops.Variable( + var_ops.VariableV1( array_ops.zeros_like(var.initialized_value(), dtypes.float32), name=var.op.name + '_unshrinked/SDCAOptimizer')) @@ -485,24 +487,44 @@ class SdcaModel(object): sparse_weights.append(batch_gathered_weights) # pylint: disable=protected-access - esu, sfw, dfw = gen_sdca_ops.sdca_optimizer( - sparse_example_indices, - sparse_feature_indices, - sparse_features_values, - self._convert_n_to_tensor(self._examples['dense_features']), - internal_convert_to_tensor(self._examples['example_weights']), - internal_convert_to_tensor(self._examples['example_labels']), - sparse_indices, - sparse_weights, - self._convert_n_to_tensor(self._slots[ - 'unshrinked_dense_features_weights']), - example_state_data, - loss_type=self._options['loss_type'], - l1=self._options['symmetric_l1_regularization'], - l2=self._symmetric_l2_regularization(), - num_loss_partitions=self._num_loss_partitions(), - num_inner_iterations=1, - adaptative=self._adaptive()) + if compat.forward_compatible(year=2018, month=10, day=30): + esu, sfw, dfw = gen_sdca_ops.sdca_optimizer_v2( + sparse_example_indices, + sparse_feature_indices, + sparse_features_values, + self._convert_n_to_tensor(self._examples['dense_features']), + internal_convert_to_tensor(self._examples['example_weights']), + internal_convert_to_tensor(self._examples['example_labels']), + sparse_indices, + sparse_weights, + self._convert_n_to_tensor(self._slots[ + 'unshrinked_dense_features_weights']), + example_state_data, + loss_type=self._options['loss_type'], + l1=self._options['symmetric_l1_regularization'], + l2=self._symmetric_l2_regularization(), + num_loss_partitions=self._num_loss_partitions(), + num_inner_iterations=1, + adaptive=self._adaptive()) + else: + esu, sfw, dfw = gen_sdca_ops.sdca_optimizer( + sparse_example_indices, + sparse_feature_indices, + sparse_features_values, + self._convert_n_to_tensor(self._examples['dense_features']), + 
internal_convert_to_tensor(self._examples['example_weights']), + internal_convert_to_tensor(self._examples['example_labels']), + sparse_indices, + sparse_weights, + self._convert_n_to_tensor(self._slots[ + 'unshrinked_dense_features_weights']), + example_state_data, + loss_type=self._options['loss_type'], + l1=self._options['symmetric_l1_regularization'], + l2=self._symmetric_l2_regularization(), + num_loss_partitions=self._num_loss_partitions(), + num_inner_iterations=1, + adaptative=self._adaptive()) # pylint: enable=protected-access with ops.control_dependencies([esu]): diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index 5015fb0848107950dd27eb81431dd308f22858bc..44a869f7c2745c594b6a4ea69a2a9e6f1b4f780a 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -48,6 +48,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): value_dtype, default_value, empty_key, + deleted_key, num_shards=1, checkpoint=True, name='ShardedMutableHashTable'): @@ -62,6 +63,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): value_dtype=value_dtype, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, checkpoint=checkpoint, name='%s-%d-of-%d' % (name, i + 1, num_shards))) self._table_shards = table_shards diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py index 553b116a3b3d76423d4700691fb6912101bebca4..2b56d0fa3a8b8564b7c73a62bd99cc900d6f5c54 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py @@ -33,6 +33,7 
@@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): with self.cached_session(): default_val = -1 empty_key = 0 + deleted_key = -1 keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = ShardedMutableDenseHashTable( @@ -40,6 +41,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): dtypes.int64, default_val, empty_key, + deleted_key, num_shards=num_shards) self.assertAllEqual(0, table.size().eval()) @@ -56,6 +58,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): with self.cached_session(): default_val = [-0.1, 0.2] empty_key = [0, 1] + deleted_key = [1, 0] keys = constant_op.constant([[11, 12], [13, 14], [15, 16]], dtypes.int64) values = constant_op.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], @@ -65,6 +68,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): dtypes.float32, default_val, empty_key, + deleted_key, num_shards=num_shards) self.assertAllEqual(0, table.size().eval()) @@ -81,6 +85,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): def testExportSharded(self): with self.cached_session(): empty_key = -2 + deleted_key = -3 default_val = -1 num_shards = 2 keys = constant_op.constant([10, 11, 12], dtypes.int64) @@ -90,6 +95,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): dtypes.int64, default_val, empty_key, + deleted_key, num_shards=num_shards) self.assertAllEqual(0, table.size().eval()) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index f3ebe3b2454a10d6a07e84e49e3f0415ddcf7c4d..787a85644c35c807df84f74cbce06f80fd0b004d 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -4,6 +4,7 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") exports_files(glob([ @@ -165,10 +166,6 @@ cc_library( "stderr_reporter.h", ], 
copts = tflite_copts(), - defines = select({ - ":with_tflite_flex": ["TFLITE_FLEX"], - "//conditions:default": [], - }), linkopts = [ ] + select({ "//tensorflow:android": [ @@ -276,6 +273,7 @@ cc_test( "testdata/0_subgraphs.bin", "testdata/2_subgraphs.bin", "testdata/empty_model.bin", + "testdata/multi_add_flex.bin", "testdata/test_model.bin", "testdata/test_model_broken.bin", ], @@ -283,6 +281,26 @@ cc_test( ":framework", "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/core/api", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +# Test model framework with the flex library linked into the target. +tf_cc_test( + name = "model_flex_test", + size = "small", + srcs = ["model_flex_test.cc"], + data = [ + "testdata/multi_add_flex.bin", + ], + tags = ["no_windows"], # TODO(b/116667551): No weak symbols with MSVC. + deps = [ + ":framework", + "//tensorflow/contrib/lite/core/api", + "//tensorflow/contrib/lite/delegates/flex:delegate", + "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 7ef26de69f2699e3d9f55a15737b96a3505cf6eb..e62c192dfcc8d38cd168b0efdc14da74967eb939 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -212,7 +212,8 @@ def json_to_tflite(name, src, out): # This is the master list of generated examples that will be made into tests. A # function called make_XXX_tests() must also appear in generate_examples.py. -# Disable a test by commenting it out. If you do, add a link to a bug or issue. +# Disable a test by adding it to the blacklists specified in +# generated_test_models_failing(). 
def generated_test_models(): return [ "add", @@ -291,12 +292,31 @@ def generated_test_models(): "tile", "topk", "transpose", - #"transpose_conv", # disabled due to b/111213074 + "transpose_conv", "unpack", "where", "zeros_like", ] +# List of models that fail generated tests for the conversion mode. +# If you have to disable a test, please add here with a link to the appropriate +# bug or issue. +def generated_test_models_failing(conversion_mode): + if not conversion_mode: + return [ + "transpose_conv", # disabled due to b/111213074 + ] + + if conversion_mode == "toco-flex": + # TODO(b/117328698): Fix and enable the known flex failures. + return [ + "lstm", + "split", + "unpack", + ] + + return [] + def generated_test_conversion_modes(): """Returns a list of conversion modes.""" @@ -307,16 +327,28 @@ def generated_test_models_all(): """Generates a list of all tests with the different converters. Returns: - List of tuples representing (conversion mode, name of test). + List of tuples representing: + (conversion mode, name of test, test tags, test args). """ conversion_modes = generated_test_conversion_modes() tests = generated_test_models() options = [] for conversion_mode in conversion_modes: + failing_tests = generated_test_models_failing(conversion_mode) for test in tests: + tags = [] + args = [] + if test in failing_tests: + tags.append("notap") + tags.append("manual") if conversion_mode: test += "_%s" % conversion_mode - options.append((conversion_mode, test)) + + # Flex conversion shouldn't suffer from the same conversion bugs + # listed for the default TFLite kernel backend. 
+ if conversion_mode == "toco-flex": + args.append("--ignore_known_bugs=false") + options.append((conversion_mode, test, tags, args)) return options def gen_zip_test(name, test_name, conversion_mode, **kwargs): @@ -336,9 +368,6 @@ def gen_zip_test(name, test_name, conversion_mode, **kwargs): # if conversion_mode == "pb2lite": # toco = "//tensorflow/contrib/lite/experimental/pb2lite:pb2lite" flags = "--ignore_toco_errors --run_with_flex" - kwargs["tags"].append("skip_already_failing") - kwargs["tags"].append("no_oss") - kwargs["tags"].append("notap") gen_zipped_test_file( name = "zip_%s" % test_name, @@ -392,14 +421,14 @@ def gen_selected_ops(name, model): tools = [tool], ) -def gen_full_model_test(conversion_modes, models, data, test_suite_tag): +def gen_full_model_test(conversion_modes, models, data, tags): """Generates Python test targets for testing TFLite models. Args: conversion_modes: List of conversion modes to test the models on. models: List of models to test. data: List of BUILD targets linking the data. - test_suite_tag: Tag identifying the model test suite. + tags: Any additional tags including the test_suite tag. """ options = [ (conversion_mode, model) @@ -422,9 +451,11 @@ def gen_full_model_test(conversion_modes, models, data, test_suite_tag): "no_oss", "no_windows", "notap", - ] + [test_suite_tag], + # TODO(nupurgarg): Remove manual tag when this test is running without the BUILD flag. 
+ "manual", + ] + tags, deps = [ - "//tensorflow/contrib/lite/testing:model_coverage_lib", + "//tensorflow/contrib/lite/testing/model_coverage:model_coverage_lib", "//tensorflow/contrib/lite/python:lite", "//tensorflow/python:client_testlib", ], diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 7809d114e2f72991be98bfa760f1f240864b5aa6..eb26c2dbdbb41ce17cab362dae14ef67f760ce27 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -120,6 +120,8 @@ typedef enum { kTfLiteBuiltinSquare = 92, kTfLiteBuiltinZerosLike = 93, kTfLiteBuiltinFill = 94, + kTfLiteBuiltinFloorMod = 95, + kTfLiteBuiltinRange = 96, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h index be9d551ee4d5e94284e1cdb38f464327361d73b7..5a5f3ad61c1c9753cffb34e8f7cc0e005a8c971f 100644 --- a/tensorflow/contrib/lite/c/builtin_op_data.h +++ b/tensorflow/contrib/lite/c/builtin_op_data.h @@ -99,6 +99,12 @@ typedef struct { TfLiteFusedActivation activation; } TfLiteSequenceRNNParams; +typedef struct { + bool time_major; + TfLiteFusedActivation activation; + bool merge_outputs; +} TfLiteBidirectionalSequenceRNNParams; + typedef enum { kTfLiteFullyConnectedWeightsFormatDefault = 0, kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, @@ -180,6 +186,26 @@ typedef struct { TfLiteLSTMKernelType kernel_type; } TfLiteLSTMParams; +typedef struct { + // Parameters needed for the underlying LSTM. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // If set to true then the first dimension is time, otherwise batch. + bool time_major; +} TfLiteUnidirectionalSequenceLSTMParams; + +typedef struct { + // Parameters for the LSTM kernel. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // If true, store the outputs of both directions in the first output. 
+ bool merge_outputs; +} TfLiteBidirectionalSequenceLSTMParams; + typedef struct { bool align_corners; } TfLiteResizeBilinearParams; diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/contrib/lite/c/builtin_op_data_test.cc index 4d0ba75e68367c9a0a7a7c9c3ac1ea14a875c201..ba458b4252c53ebc91adcd0afbd16f783037dd42 100644 --- a/tensorflow/contrib/lite/c/builtin_op_data_test.cc +++ b/tensorflow/contrib/lite/c/builtin_op_data_test.cc @@ -73,6 +73,8 @@ TEST(IntArray, CanCompileStructs) { TfLiteFakeQuantParams fake_quant_params; TfLitePackParams pack_params; TfLiteOneHotParams one_hot_params; + TfLiteBidirectionalSequenceRNNParams bidi_sequence_rnn_params; + TfLiteBidirectionalSequenceLSTMParams bidi_sequence_lstm_params; } } // namespace tflite diff --git a/tensorflow/contrib/lite/c/c_api_internal.c b/tensorflow/contrib/lite/c/c_api_internal.c index 8a0c177b1948df9b98e68f6cc6f44628ea8407a3..8be37945ca2a5ddf3c8cedc5a3ae5e34da8a4b9b 100644 --- a/tensorflow/contrib/lite/c/c_api_internal.c +++ b/tensorflow/contrib/lite/c/c_api_internal.c @@ -28,10 +28,15 @@ int TfLiteIntArrayGetSizeInBytes(int size) { int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) { if (a == b) return 1; if (a == NULL || b == NULL) return 0; - if (a->size != b->size) return 0; + return TfLiteIntArrayEqualsArray(a, b->size, b->data); +} + +int TfLiteIntArrayEqualsArray(TfLiteIntArray* a, int b_size, int b_data[]) { + if (a == NULL) return (b_size == 0); + if (a->size != b_size) return 0; int i = 0; for (; i < a->size; i++) - if (a->data[i] != b->data[i]) return 0; + if (a->data[i] != b_data[i]) return 0; return 1; } diff --git a/tensorflow/contrib/lite/c/c_api_internal.h b/tensorflow/contrib/lite/c/c_api_internal.h index ee3dff6792a33a575e75fe7a1ef3dc7985be9c1d..fdc9ff634a19d348ab2dfae60d94722619dfec06 100644 --- a/tensorflow/contrib/lite/c/c_api_internal.h +++ b/tensorflow/contrib/lite/c/c_api_internal.h @@ -88,9 +88,12 @@ int TfLiteIntArrayGetSizeInBytes(int size); // 
This returns a pointer, that you must free using TfLiteIntArrayFree(). TfLiteIntArray* TfLiteIntArrayCreate(int size); -// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise. +// Check if two intarrays are equal. Returns 1 if they are equal, 0 otherwise. int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b); +// Check if an intarray equals an array. Returns 1 if equals, 0 otherwise. +int TfLiteIntArrayEqualsArray(TfLiteIntArray* a, int b_size, int b_data[]); + // Create a copy of an array passed as `src`. // You are expected to free memory with TfLiteIntArrayFree TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src); diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc index e6900e0950305d5d482814c1f617a42231998cc0..fe56c4ebf9238b2b6a94371620b06974b9c42e7f 100644 --- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc @@ -224,10 +224,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { - TfLiteSequenceRNNParams* params = - allocator->AllocatePOD(); + auto params = allocator->AllocatePOD(); if (auto* sequence_rnn_params = op->builtin_options_as_SequenceRNNOptions()) { params->activation = @@ -237,6 +235,19 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: { + auto params = + allocator->AllocatePOD(); + if (auto* bidi_sequence_rnn_params = + op->builtin_options_as_BidirectionalSequenceRNNOptions()) { + params->activation = parse_activation( + bidi_sequence_rnn_params->fused_activation_function()); + params->time_major = bidi_sequence_rnn_params->time_major(); + params->merge_outputs = 
bidi_sequence_rnn_params->merge_outputs(); + } + *builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_RNN: { TfLiteRNNParams* params = allocator->AllocatePOD(); if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { @@ -360,10 +371,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case BuiltinOperator_LSTM: { - TfLiteLSTMParams* params = allocator->AllocatePOD(); + auto params = allocator->AllocatePOD(); if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { params->activation = parse_activation(lstm_params->fused_activation_function()); @@ -381,6 +390,34 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: { + auto* params = + allocator->AllocatePOD(); + if (auto* seq_lstm_params = + op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) { + params->activation = + parse_activation(seq_lstm_params->fused_activation_function()); + params->cell_clip = seq_lstm_params->cell_clip(); + params->proj_clip = seq_lstm_params->proj_clip(); + params->time_major = seq_lstm_params->time_major(); + } + *builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: { + auto params = + allocator->AllocatePOD(); + if (auto* bidi_lstm_params = + op->builtin_options_as_BidirectionalSequenceLSTMOptions()) { + params->activation = + parse_activation(bidi_lstm_params->fused_activation_function()); + params->cell_clip = bidi_lstm_params->cell_clip(); + params->proj_clip = bidi_lstm_params->proj_clip(); + params->merge_outputs = bidi_lstm_params->merge_outputs(); + } + *builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_RESIZE_BILINEAR: { auto* params = allocator->AllocatePOD(); 
if (auto* schema_params = @@ -614,6 +651,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_SQUARE: case BuiltinOperator_ZEROS_LIKE: case BuiltinOperator_FILL: + case BuiltinOperator_FLOOR_MOD: + case BuiltinOperator_RANGE: break; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/delegates/flex/BUILD b/tensorflow/contrib/lite/delegates/flex/BUILD index bf5d91899ca63142f69401229b9e06b27b6c2b0b..2f866eaecb801695d800565e195f959d55a88201 100644 --- a/tensorflow/contrib/lite/delegates/flex/BUILD +++ b/tensorflow/contrib/lite/delegates/flex/BUILD @@ -2,7 +2,7 @@ # This is a TF Lite delegate that is powered by TensorFlow's Eager. # package(default_visibility = [ - "//visibility:public", + "//visibility:private", ]) licenses(["notice"]) # Apache 2.0 @@ -20,7 +20,7 @@ cc_library( "//tensorflow/contrib/lite:kernel_api", ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite_no_runtime", + "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ "//tensorflow/core:framework", @@ -42,14 +42,38 @@ tf_cc_test( ], ) +# Delegate implementation that pulls in the standard set of TensorFlow ops and +# kernels. cc_library( name = "delegate", + hdrs = [ + "delegate.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":delegate_only_runtime", + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:tensorflow", + ], + }), + alwayslink = 1, +) + +# Delegate implementation that does *not* pull in the standard set of TensorFlow +# ops and kernels. 
+cc_library( + name = "delegate_only_runtime", srcs = [ "delegate.cc", ], hdrs = [ "delegate.h", ], + visibility = ["//visibility:public"], deps = [ ":buffer_map", ":delegate_data", @@ -60,12 +84,13 @@ cc_library( "//tensorflow/contrib/lite:util", ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite_no_runtime", + "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ "//tensorflow/core:lib", ], }), + alwayslink = 1, ) tf_cc_test( @@ -132,12 +157,12 @@ cc_library( # set of core TensorFlow kernels. We may want to revisit this dependency # to allow selective registration via build targets. "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:framework", - "//tensorflow/core:tensorflow", ], }), ) @@ -151,7 +176,14 @@ tf_cc_test( ":kernel", ":test_util", "@com_google_googletest//:gtest", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:tensorflow", + ], + }), ) cc_library( @@ -178,7 +210,7 @@ cc_library( "//tensorflow/contrib/lite:kernel_api", ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite_no_runtime", + "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ "//tensorflow/core:lib", diff --git a/tensorflow/contrib/lite/delegates/flex/delegate.cc b/tensorflow/contrib/lite/delegates/flex/delegate.cc index ba065a8ff527e1fa3355129dddab91a5df89fe7b..c72b0cf51383897ce3afec0c39ed6bfe178d88c1 100644 --- a/tensorflow/contrib/lite/delegates/flex/delegate.cc +++ b/tensorflow/contrib/lite/delegates/flex/delegate.cc @@ -83,6 +83,15 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context, } // namespace delegate } // namespace flex +// Corresponding weak declaration found in 
lite/model.cc. +std::unique_ptr +AcquireFlexDelegate() { + return std::unique_ptr( + tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) { + delete reinterpret_cast(delegate); + }); +} + std::unique_ptr FlexDelegate::Create() { std::unique_ptr delegate_data; if (!flex::DelegateData::Create(&delegate_data).ok()) { diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/BUILD b/tensorflow/contrib/lite/experimental/examples/lstm/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..2125f218ca877f94ec9f4d98928b6a1c8f2576eb --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/lstm/BUILD @@ -0,0 +1,40 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "tflite_lstm", + srcs = ["tflite_lstm.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/lite/python:lite", + "//tensorflow/python:framework", + "@six_archive//:six", + ], +) + +py_test( + name = "unidirectional_sequence_lstm_test", + size = "large", + srcs = ["unidirectional_sequence_lstm_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_oss", + "no_pip", + ], + deps = [ + ":tflite_lstm", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/lite/python:lite", + "//tensorflow/examples/tutorials/mnist:input_data", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//tensorflow/python/tools:optimize_for_inference", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py b/tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..2357743266f7082a5a003153718de08c83174ea5 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py @@ -0,0 
+1,398 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TfLite LSTMCell wrapper.
+
+TODO(renjieliu): Find a better home for this one.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import tensorflow as tf
+
+from tensorflow.contrib.lite.python import lite
+from tensorflow.python.keras import activations
+from tensorflow.python.keras import initializers
+from tensorflow.python.layers import base as base_layer
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.platform import tf_logging as logging
+
+
+class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
+  """Long short-term memory unit (LSTM) recurrent network cell.
+
+  This is used only for TfLite, it provides hints and it also makes the
+  variables in the desired format for the tflite ops (transposed and separated).
+
+  The default non-peephole implementation is based on:
+
+    https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
+
+  Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+  "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
+
+  The peephole implementation is based on:
+
+    https://research.google.com/pubs/archive/43905.pdf
+
+  Hasim Sak, Andrew Senior, and Francoise Beaufays.
+  "Long short-term memory recurrent neural network architectures for
+   large scale acoustic modeling." INTERSPEECH, 2014.
+
+  The class uses optional peep-hole connections, optional cell clipping, and
+  an optional projection layer.
+
+  Note that this cell is not optimized for performance. Please use
+  `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
+  `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
+  better performance on CPU.
+  """
+
+  def __init__(self,
+               num_units,
+               use_peepholes=False,
+               cell_clip=None,
+               initializer=None,
+               num_proj=None,
+               proj_clip=None,
+               num_unit_shards=None,
+               num_proj_shards=None,
+               forget_bias=1.0,
+               state_is_tuple=True,
+               activation=None,
+               reuse=None,
+               name=None,
+               dtype=None):
+    """Initialize the parameters for an LSTM cell.
+
+    Args:
+      num_units: int, The number of units in the LSTM cell.
+      use_peepholes: bool, set True to enable diagonal/peephole connections.
+      cell_clip: (optional) A float value, if provided the cell state is clipped
+        by this value prior to the cell output activation.
+      initializer: (optional) The initializer to use for the weight and
+        projection matrices.
+      num_proj: (optional) int, The output dimensionality for the projection
+        matrices. If None, no projection is performed.
+      proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
+        provided, then the projected values are clipped elementwise to within
+        `[-proj_clip, proj_clip]`.
+      num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a
+        variable_scope partitioner instead.
+      num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a
+        variable_scope partitioner instead.
+      forget_bias: Biases of the forget gate are initialized by default to 1 in
+        order to reduce the scale of forgetting at the beginning of the
+        training. Must set it manually to `0.0` when restoring from CudnnLSTM
+        trained checkpoints.
+      state_is_tuple: If True, accepted and returned states are 2-tuples of the
+        `c_state` and `m_state`. If False, they are concatenated along the
+        column axis. This latter behavior will soon be deprecated.
+      activation: Activation function of the inner states. Default: `tanh`.
+      reuse: (optional) Python boolean describing whether to reuse variables in
+        an existing scope. If not `True`, and the existing scope already has
+        the given variables, an error is raised.
+      name: String, the name of the layer. Layers with the same name will share
+        weights, but to avoid mistakes we require reuse=True in such cases.
+      dtype: Default dtype of the layer (default of `None` means use the type of
+        the first input). Required when `build` is called before `call`. When
+        restoring from CudnnLSTM-trained checkpoints, use
+        `CudnnCompatibleLSTMCell` instead.
+    """
+    super(TFLiteLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+    # TODO(raziel): decide if we want to just support tuples (yes please!).
+    if not state_is_tuple:
+      logging.warn(
+          "%s: Using a concatenated state is slower and will soon be "
+          "deprecated. Use state_is_tuple=True.", self)
+    if num_unit_shards is not None or num_proj_shards is not None:
+      logging.warn(
+          "%s: The num_unit_shards and proj_unit_shards parameters are "
+          "deprecated and will be removed in Jan 2017. "
+          "Use a variable scope with a partitioner instead.", self)
+
+    # Inputs must be 2-dimensional.
+    # TODO(raziel): layers stuff -- chop if un-layerizing Op.
+    self.input_spec = base_layer.InputSpec(ndim=2)
+
+    self._tflite_wrapper = lite.OpHint("UnidirectionalSequenceLstm")
+
+    self._num_units = num_units
+    self._use_peepholes = use_peepholes
+    self._cell_clip = cell_clip
+    self._initializer = initializer
+    self._num_proj = num_proj
+    self._proj_clip = proj_clip
+    self._num_unit_shards = num_unit_shards
+    self._num_proj_shards = num_proj_shards
+    self._forget_bias = forget_bias
+    self._state_is_tuple = state_is_tuple
+    self._activation = activation or math_ops.tanh
+
+    self._output_size = num_proj if num_proj else num_units
+    self._state_size = (
+        tf.nn.rnn_cell.LSTMStateTuple(num_units, self._output_size)
+        if state_is_tuple else num_units + self._output_size)
+
+  @property
+  def state_size(self):
+    return self._state_size
+
+  @property
+  def output_size(self):
+    return self._output_size
+
+  def build(self, inputs_shape):
+    """Build TfLite LSTM cell graph.
+
+    Args:
+      inputs_shape: The inputs_shape must be known, and is [batch_size,
+        input_size] shape.
+
+    Raises:
+      ValueError: if the inputs_shape is invalid.
+    """
+    if len(inputs_shape) != 2 or inputs_shape[1].value is None:
+      raise ValueError("Invalid inputs_shape, saw shape: %s" % inputs_shape)
+
+    input_depth = inputs_shape[1].value
+    maybe_partitioner = (
+        partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
+        if self._num_unit_shards is not None else None)
+    input_weight_shape = [self._num_units, input_depth]
+    cell_weight_shape = [self._num_units, self._output_size]
+    bias_shape = [self._num_units]
+
+    def add_variable_wrapped(name, shape, initializer, index, partitioner):
+      # Creates the variable and registers it as an OpHint input at `index`.
+      # The hint input is labelled with the variable's own `name`.
+      var = self.add_variable(
+          name, shape=shape, initializer=initializer, partitioner=partitioner)
+      return self._tflite_wrapper.add_input(
+          var, name=name, index_override=index)
+
+    weight_initializer = self._initializer
+    if self.dtype is None:
+      bias_initializer = init_ops.zeros_initializer
+    else:
+      bias_initializer = init_ops.zeros_initializer(dtype=self.dtype)
+
+    self.input_to_input_w = add_variable_wrapped(
+        "input_to_input_w", input_weight_shape, weight_initializer, 1,
+        maybe_partitioner)
+    self.input_to_forget_w = add_variable_wrapped(
+        "input_to_forget_w", input_weight_shape, weight_initializer, 2,
+        maybe_partitioner)
+    self.input_to_cell_w = add_variable_wrapped(
+        "input_to_cell_w", input_weight_shape, weight_initializer, 3,
+        maybe_partitioner)
+    self.input_to_output_w = add_variable_wrapped(
+        "input_to_output_w", input_weight_shape, weight_initializer, 4,
+        maybe_partitioner)
+    self.cell_to_input_w = add_variable_wrapped(
+        "cell_to_input_w", cell_weight_shape, weight_initializer, 5,
+        maybe_partitioner)
+    self.cell_to_forget_w = add_variable_wrapped(
+        "cell_to_forget_w", cell_weight_shape, weight_initializer, 6,
+        maybe_partitioner)
+    self.cell_to_cell_w = add_variable_wrapped(
+        "cell_to_cell_w", cell_weight_shape, weight_initializer, 7,
+        maybe_partitioner)
+    self.cell_to_output_w = add_variable_wrapped(
+        "cell_to_output_w", cell_weight_shape, weight_initializer, 8,
+        maybe_partitioner)
+
+    self.input_bias = add_variable_wrapped(
+        "input_bias", bias_shape, bias_initializer, 12, maybe_partitioner)
+    self.forget_bias = add_variable_wrapped(
+        "forget_bias", bias_shape, bias_initializer, 13, maybe_partitioner)
+    self.cell_bias = add_variable_wrapped(
+        "cell_bias", bias_shape, bias_initializer, 14, maybe_partitioner)
+    self.output_bias = add_variable_wrapped(
+        "output_bias", bias_shape, bias_initializer, 15, maybe_partitioner)
+
+    # index 9, 10, 11.
+    # f stands for forget, i stands for input and o stands for output.
+    if self._use_peepholes:
+      self._w_f_diag = add_variable_wrapped("w_f_diag", [self._num_units],
+                                            self._initializer, 9,
+                                            maybe_partitioner)
+      self._w_i_diag = add_variable_wrapped("w_i_diag", [self._num_units],
+                                            self._initializer, 10,
+                                            maybe_partitioner)
+      self._w_o_diag = add_variable_wrapped("w_o_diag", [self._num_units],
+                                            self._initializer, 11,
+                                            maybe_partitioner)
+
+    # index 16 for proj kernel.
+    if self._num_proj is not None:
+      maybe_proj_partitioner = (
+          partitioned_variables.fixed_size_partitioner(self._num_proj_shards)
+          if self._num_proj_shards is not None else None)
+      self._proj_kernel = add_variable_wrapped(
+          "projection/kernel", [self._num_proj, self._num_units],
+          self._initializer,
+          16,
+          partitioner=maybe_proj_partitioner)
+
+    self.built = True
+
+  def call(self, inputs, state):
+    """Run one step of LSTM.
+
+    Args:
+      inputs: input Tensor, 2D, `[batch, input_size]`.
+      state: if `state_is_tuple` is False, this must be a state Tensor, `2-D,
+        [batch, state_size]`. If `state_is_tuple` is True, this must be a tuple
+        of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`.
+
+    Returns:
+      A tuple containing:
+
+      - A `2-D, [batch, output_dim]`, Tensor representing the output of the
+        LSTM after reading `inputs` when previous state was `state`.
+        Here output_dim is:
+           num_proj if num_proj was set,
+           num_units otherwise.
+      - Tensor(s) representing the new state of LSTM after reading `inputs` when
+        the previous state was `state`. Same type and shape(s) as `state`.
+
+    Raises:
+      ValueError: If input size cannot be inferred from inputs via
+        static shape inference.
+    """
+    inputs = self._tflite_wrapper.add_input(
+        inputs, tag="input", name="input", aggregate="stack", index_override=0)
+
+    # Make sure inputs and the input weights have the same dtype.
+    assert inputs.dtype == self.input_to_input_w.dtype
+
+    num_proj = self._num_units if self._num_proj is None else self._num_proj
+    sigmoid = math_ops.sigmoid
+
+    if self._state_is_tuple:
+      (c_prev, m_prev) = state
+    else:
+      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
+      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
+
+    # Note: For TfLite, cell_state is at index 19 while activation state at
+    # index 18.
+    c_prev = self._tflite_wrapper.add_input(
+        c_prev,
+        tag="c_prev",
+        name="c_prev",
+        aggregate="first",
+        index_override=19)
+    m_prev = self._tflite_wrapper.add_input(
+        m_prev,
+        tag="m_prev",
+        name="m_prev",
+        aggregate="first",
+        index_override=18)
+
+    input_size = inputs.get_shape().with_rank(2)[1]
+    if input_size.value is None:
+      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+
+    inputs_and_m_prev = array_ops.concat([inputs, m_prev], axis=1)
+
+    # i: input gate, before the sigmoid.
+    # f: forget gate, before the sigmoid.
+    # o: output gate, before the sigmoid.
+    # j: candidate cell input, before the cell activation.
+    # c: the new cell state.
+    # m: the new output (activation state).
+    i = nn_ops.bias_add(
+        tf.matmul(
+            inputs_and_m_prev,
+            tf.concat([self.input_to_input_w, self.cell_to_input_w], axis=1),
+            transpose_b=True), self.input_bias)
+    f = nn_ops.bias_add(
+        tf.matmul(
+            inputs_and_m_prev,
+            tf.concat([self.input_to_forget_w, self.cell_to_forget_w], axis=1),
+            transpose_b=True), self.forget_bias)
+    o = nn_ops.bias_add(
+        tf.matmul(
+            inputs_and_m_prev,
+            tf.concat([self.input_to_output_w, self.cell_to_output_w], axis=1),
+            transpose_b=True), self.output_bias)
+    j = nn_ops.bias_add(
+        tf.matmul(
+            inputs_and_m_prev,
+            tf.concat([self.input_to_cell_w, self.cell_to_cell_w], axis=1),
+            transpose_b=True), self.cell_bias)
+
+    # Diagonal connections
+    if self._use_peepholes:
+      c = (
+          sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
+          sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
+    else:
+      c = (
+          sigmoid(f + self._forget_bias) * c_prev +
+          sigmoid(i) * self._activation(j))
+
+    if self._cell_clip is not None:
+      # pylint: disable=invalid-unary-operand-type
+      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
+      # pylint: enable=invalid-unary-operand-type
+    if self._use_peepholes:
+      m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
+    else:
+      m = sigmoid(o) * self._activation(c)
+
+    if self._num_proj is not None:
+      transposed_proj_kernel = tf.transpose(self._proj_kernel)
+      m = math_ops.matmul(m, transposed_proj_kernel)
+
+      if self._proj_clip is not None:
+        # pylint: disable=invalid-unary-operand-type
+        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
+        # pylint: enable=invalid-unary-operand-type
+
+    c = self._tflite_wrapper.add_output(
+        c, tag="c", name="c", aggregate="last", index_override=1)
+    m = self._tflite_wrapper.add_output(
+        m, tag="m", name="m", index_override=2, aggregate="stack")
+
+    new_state = (
+        tf.nn.rnn_cell.LSTMStateTuple(c, m)
+        if self._state_is_tuple else array_ops.concat([c, m], 1))
+    return m, new_state
+
+  def get_config(self):
+    config = {
+        "num_units": self._num_units,
+        "use_peepholes": self._use_peepholes,
+        "cell_clip": self._cell_clip,
+        "initializer": initializers.serialize(self._initializer),
+        "num_proj": self._num_proj,
+        "proj_clip": self._proj_clip,
+        "num_unit_shards": self._num_unit_shards,
+        "num_proj_shards": self._num_proj_shards,
+        "forget_bias": self._forget_bias,
+        "state_is_tuple": self._state_is_tuple,
+        "activation": activations.serialize(self._activation),
+        "reuse": self._reuse,
+    }
+    base_config = super(TFLiteLSTMCell, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py b/tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ca977518cb11db5f7ed33afa25ead5c02221a95
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
@@ -0,0 +1,226 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tempfile +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.lite.experimental.examples.lstm.tflite_lstm import TFLiteLSTMCell +from tensorflow.examples.tutorials.mnist import input_data +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test +from tensorflow.python.tools import optimize_for_inference_lib + +# Number of steps to train model. +TRAIN_STEPS = 1 + +CONFIG = tf.ConfigProto(device_count={"GPU": 0}) + + +class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase): + + def setUp(self): + tf.reset_default_graph() + # Import MNIST dataset + self.mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) + + # Define constants + # Unrolled through 28 time steps + self.time_steps = 28 + # Rows of 28 pixels + self.n_input = 28 + # Learning rate for Adam optimizer + self.learning_rate = 0.001 + # MNIST is meant to be classified in 10 classes(0-9). + self.n_classes = 10 + # Batch size + self.batch_size = 16 + # Lstm Units. + self.num_units = 64 + + def buildLstmLayer(self): + return tf.nn.rnn_cell.MultiRNNCell([ + TFLiteLSTMCell( + self.num_units, use_peepholes=True, forget_bias=0, name="rnn1"), + TFLiteLSTMCell(self.num_units, num_proj=64, forget_bias=0, name="rnn2"), + TFLiteLSTMCell( + self.num_units // 2, + use_peepholes=True, + num_proj=64, + forget_bias=0, + name="rnn3"), + TFLiteLSTMCell(self.num_units, forget_bias=0, name="rnn4") + ]) + + def buildModel(self, lstm_layer, is_dynamic_rnn, is_train): + # Weights and biases for output softmax layer. 
+ out_weights = tf.Variable( + tf.random_normal([self.num_units, self.n_classes])) + out_bias = tf.Variable(tf.random_normal([self.n_classes])) + + # input image placeholder + x = tf.placeholder( + "float", [None, self.time_steps, self.n_input], name="INPUT_IMAGE") + + # For dynamic_rnn, train with dynamic_rnn and inference with static_rnn. + # x is shaped [batch_size,time_steps,num_inputs] + if is_dynamic_rnn: + if is_train: + lstm_input = x + outputs, _ = tf.nn.dynamic_rnn(lstm_layer, lstm_input, dtype="float32") + outputs = tf.unstack(outputs, axis=1) + else: + lstm_input = tf.unstack(x, self.time_steps, 1) + outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32") + else: + lstm_input = tf.unstack(x, self.time_steps, 1) + outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32") + + # Compute logits by multiplying outputs[-1] of shape [batch_size,num_units] + # by the softmax layer's out_weight of shape [num_units,n_classes] + # plus out_bias + prediction = tf.matmul(outputs[-1], out_weights) + out_bias + output_class = tf.nn.softmax(prediction, name="OUTPUT_CLASS") + + return x, prediction, output_class + + def trainModel(self, x, prediction, output_class, sess): + # input label placeholder + y = tf.placeholder("float", [None, self.n_classes]) + # Loss function + loss = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y)) + # Optimization + opt = tf.train.AdamOptimizer( + learning_rate=self.learning_rate).minimize(loss) + + # Initialize variables + init = tf.global_variables_initializer() + sess.run(init) + for _ in range(TRAIN_STEPS): + batch_x, batch_y = self.mnist.train.next_batch( + batch_size=self.batch_size, shuffle=False) + + batch_x = batch_x.reshape((self.batch_size, self.time_steps, + self.n_input)) + sess.run(opt, feed_dict={x: batch_x, y: batch_y}) + + def saveAndRestoreModel(self, lstm_layer, sess, saver, is_dynamic_rnn): + model_dir = tempfile.mkdtemp() + saver.save(sess, model_dir) + 
+ # Reset the graph. + tf.reset_default_graph() + x, prediction, output_class = self.buildModel( + lstm_layer, is_dynamic_rnn, is_train=False) + + new_sess = tf.Session(config=CONFIG) + saver = tf.train.Saver() + saver.restore(new_sess, model_dir) + return x, prediction, output_class, new_sess + + def getInferenceResult(self, x, output_class, sess): + b1, _ = self.mnist.train.next_batch(batch_size=1) + sample_input = np.reshape(b1, (1, self.time_steps, self.n_input)) + + expected_output = sess.run(output_class, feed_dict={x: sample_input}) + frozen_graph = tf.graph_util.convert_variables_to_constants( + sess, sess.graph_def, [output_class.op.name]) + return sample_input, expected_output, frozen_graph + + def tfliteInvoke(self, graph, test_inputs, outputs): + tf.reset_default_graph() + # Turn the input into placeholder of shape 1 + tflite_input = tf.placeholder( + "float", [1, self.time_steps, self.n_input], name="INPUT_IMAGE_LITE") + tf.import_graph_def(graph, name="", input_map={"INPUT_IMAGE": tflite_input}) + with tf.Session() as sess: + curr = sess.graph_def + curr = tf.contrib.lite.convert_op_hints_to_stubs(graph_def=curr) + + curr = optimize_for_inference_lib.optimize_for_inference( + curr, ["INPUT_IMAGE_LITE"], ["OUTPUT_CLASS"], + [tf.float32.as_datatype_enum]) + + tflite = tf.contrib.lite.toco_convert( + curr, [tflite_input], [outputs], allow_custom_ops=False) + interpreter = tf.contrib.lite.Interpreter(model_content=tflite) + + try: + interpreter.allocate_tensors() + except ValueError: + assert False + + input_index = (interpreter.get_input_details()[0]["index"]) + interpreter.set_tensor(input_index, test_inputs) + interpreter.invoke() + output_index = (interpreter.get_output_details()[0]["index"]) + result = interpreter.get_tensor(output_index) + # Reset all variables so it will not pollute other inferences. 
+ interpreter.reset_all_variables() + return result + + def testStaticRnnMultiRnnCell(self): + sess = tf.Session(config=CONFIG) + + x, prediction, output_class = self.buildModel( + self.buildLstmLayer(), is_dynamic_rnn=False, is_train=True) + self.trainModel(x, prediction, output_class, sess) + + saver = tf.train.Saver() + x, prediction, output_class, new_sess = self.saveAndRestoreModel( + self.buildLstmLayer(), sess, saver, is_dynamic_rnn=False) + + test_inputs, expected_output, frozen_graph = self.getInferenceResult( + x, output_class, new_sess) + + result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-3)) + + def testDynamicRnnMultiRnnCell(self): + sess = tf.Session(config=CONFIG) + + x, prediction, output_class = self.buildModel( + self.buildLstmLayer(), is_dynamic_rnn=True, is_train=True) + self.trainModel(x, prediction, output_class, sess) + + # Since we don't yet support OpHints for dynamic, we will load the model + # back in as a static model. This requires the variables to have the same + # names as if they were trained as a static. Thus, we get rid of while/rnn + # names. 
+ variables_to_save = {} + for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES): + op_name = i.name + if op_name.startswith("while/rnn/"): + op_name = op_name.split("while/rnn/")[1] + if op_name.endswith(":0"): + op_name = op_name.split(":0")[0] + variables_to_save[op_name] = i + saver = tf.train.Saver(variables_to_save) + + x, prediction, output_class, new_sess = self.saveAndRestoreModel( + self.buildLstmLayer(), sess, saver, is_dynamic_rnn=True) + + test_inputs, expected_output, frozen_graph = self.getInferenceResult( + x, output_class, new_sess) + + result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-3)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/lite/experimental/micro/BUILD b/tensorflow/contrib/lite/experimental/micro/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..df1036bc8b9cc84f4b63ae2a771e3aa8f8989060 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/BUILD @@ -0,0 +1,76 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl", + "tflite_micro_cc_test", +) + +cc_library( + name = "micro_framework", + srcs = [ + "micro_error_reporter.cc", + "micro_interpreter.cc", + "micro_mutable_op_resolver.cc", + "simple_tensor_allocator.cc", + ], + hdrs = [ + "compatibility.h", + "micro_error_reporter.h", + "micro_interpreter.h", + "micro_mutable_op_resolver.h", + "simple_tensor_allocator.h", + ], + deps = [ + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/core/api", + "//tensorflow/contrib/lite/schema:schema_fbs", + ], +) + +tflite_micro_cc_test( + name = "micro_error_reporter_test", + srcs = [ + "micro_error_reporter_test.cc", + ], + deps = [ + ":micro_framework", + ], +) + +tflite_micro_cc_test( + name 
= "micro_mutable_op_resolver_test", + srcs = [ + "micro_mutable_op_resolver_test.cc", + ], + deps = [ + ":micro_framework", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "micro_interpreter_test", + srcs = [ + "micro_interpreter_test.cc", + ], + deps = [ + ":micro_framework", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "simple_tensor_allocator_test", + srcs = [ + "simple_tensor_allocator_test.cc", + ], + deps = [ + ":micro_framework", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + ], +) diff --git a/tensorflow/contrib/lite/experimental/micro/README.md b/tensorflow/contrib/lite/experimental/micro/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fc539db62e20b99324dd034bbc87d338349f102f --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/README.md @@ -0,0 +1,127 @@ +# TensorFlow Lite for Microcontrollers + +This an experimental port of TensorFlow Lite aimed at micro controllers and other devices with only kilobytes of memory. It doesn't require any operating system support, any standard C or C++ libraries, or dynamic memory allocation, so it's designed to be portable even to 'bare metal' systems. The core runtime fits in 16KB on a Cortex M3, and with enough operators to run a speech keyword detection model, takes up a total of 22KB. + +The design goals are for the framework to be: + +- **Readable**: We want embedded software engineers to be able to understand what's required to run ML inference without having to study research papers. We've tried to keep the code base small, modular, and have reference implementations of all operations to help with this. + +- **Easy to modify**: We know that there are a lot of different platforms and requirements in the embedded world, and we don't expect to cover all of them in one framework. 
Instead, we're hoping that it can be a good starting point for developers to build on top of to meet their own needs. For example, we tried to make it easy to replace the implementations of key computational operators that are often crucial for performance, without having to touch the data flow and other runtime code. We want it to make more sense to use our workflow to handle things like model import and less-important operations, and customize the parts that matter, rather than having to reimplement everything in your own engine. + +- **Well-tested**: If you're modifying code, you need to know if your changes are correct. Having an easy way to test lets you develop much faster. To help there, we've written tests for all the components, and we've made sure that the tests can be run on almost any platform, with no dependencies apart from the ability to log text to a debug console somewhere. We also provide an easy way to run all the tests on-device as part of an automated test framework, and we use qemu/Renode emulation so that tests can be run even without physical devices present. + +- **Easy to integrate**: We want to be as open a system as possible, and use the best code available for each platform. To do that, we're going to rely on projects like [CMSIS-NN](https://www.keil.com/pack/doc/CMSIS/NN/html/index.html), [uTensor](https://github.com/uTensor/uTensor), and other vendor libraries to handle as much performance-critical code as possible. We know that there are an increasing number of options to accelerate neural networks on microcontrollers, so we're aiming to be a good host for deploying those hardware technologies too. + +- **Compatible**: We're using the same file schema, interpreter API, and kernel interface as regular TensorFlow Lite, so we leverage the large existing set of tools, documentation, and examples for the project. 
The biggest barrier to deploying ML models is getting them from a training environment into a form that's easy to run inference on, so we see reusing this rich ecosystem as being crucial to being easily usable. We also hope to integrate this experimental work back into the main codebase in the future. + +To meet those goals, we've made some tradeoffs: + +- **Simple C++**: To help with readability, our code is written in a modern version of C++, but we generally treat it as a "better C", rather than relying on more complex features such as template meta-programming. As mentioned earlier, we avoid any use of dynamic memory allocation (new/delete) or the standard C/C++ libraries, so we believe this should still be fairly portable. It does mean that some older devices with C-only toolchains won't be supported, but we're hoping that the reference operator implementations (which are simple C-like functions) can still be useful in those cases. The interfaces are also designed to be C-only, so it should be possible to integrate the resulting library with pure C projects. + +- **Interpreted**: Code generation is a popular pattern for embedded code, because it gives standalone code that's easy to modify and step through, but we've chosen to go with an interpreted approach. In our internal microcontroller work we've found that using an extremely stripped-down interpreter with almost no dependencies gives us a lot of the same advantages, but is easier to maintain. For example, when new updates come out for the underlying library, you can just merge your local modifications in a single step, rather than having to regenerate new code and then patch in any changes you subsequently made.
The coarse granularity of the interpreted primitives means that each operation call typically takes hundreds of thousands of instruction cycles at least, so we don't see noticeable performance gains from avoiding what's essentially a single switch statement at the interpreter level to call each operation. We're still working on improving the packaging though, for example we're considering having the ability to snapshot all the source files and headers used for a particular model, being able to compile the code and data together as a library, and then access it through a minimal set of C interface calls which hide the underlying complexity. + +- **Flatbuffers**: We represent our models using [the standard flatbuffer schema used by the rest of TensorFlow Lite](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/schema/schema.fbs), with the difference that we always keep it in read-only program memory (typically flash) rather than relying on having a file system to read it from. This is a good fit because flatbuffer's serialized format is designed to be mapped into memory without requiring any extra memory allocations or modifications to access it. All of the functions to read model values work directly on the serialized bytes, and large sections of data like weights are directly accessible as sequential C-style arrays of their data type, with no strides or unpacking needed. We do get a lot of value from using flatbuffers, but there is a cost in complexity. The flat buffer library code is all inline [inside the main headers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/schema/schema_generated.h), but it isn't straightforward to inspect their implementations, and the model data structures aren't easy to comprehend from the debugger. 
The header for the schema itself also has to be periodically updated when new information is added to the file format, though we try to handle that transparently for most developers by checking in a pre-generated version. + +- **Code Duplication**: Some of the code in this prototype largely duplicates the logic in other parts of the TensorFlow Lite code base, for example the operator wrappers. We've tried to share as much as we can between the two interpreters, but there are some assumptions built into the original runtime that make this difficult. We'll be working on modularizing the main interpreter so that we can move to an entirely shared system. + +This initial preview release is designed to get early feedback, and is not intended to be a final product. It only includes enough operations to run a simple keyword recognition model, and the implementations are not optimized. We're hoping this will be a good way to get feedback and collaborate to improve the framework. + +## Getting Started + +Building requires a Linux or OS X machine. + + - Open a terminal + - Download the TensorFlow source with `git clone https://github.com/tensorflow/tensorflow` + - Enter the source root directory by running `cd tensorflow` + - Download the dependencies by running `tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh`. This may take a few minutes + - Build and test the library with `make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile test` + +You should see a series of compilation steps, followed by `~~~ALL TESTS +PASSED~~~` for the various tests of the code that it will run. If there's an +error, you should get an informative message from make about what went wrong. + +These tests are all built as simple binaries with few dependencies, so you can run them manually.
For example, here's how to run the depthwise convolution test, and its output: + +``` +tensorflow/contrib/lite/experimental/micro/tools/make/gen/linux_x86_64/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test + +Testing SimpleTest +Testing SimpleTestQuantized +Testing SimpleTestRelu +Testing SimpleTestReluQuantized +4/4 tests passed +~~~ALL TESTS PASSED~~~ +``` + +Looking at the [depthwise_conv_test.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc) code, you'll see a sequence that looks like this: + +``` +... +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(SimpleTest) { +... +} +... +TF_LITE_MICRO_TESTS_END +``` + +These macros work a lot like +[the Google test framework](https://github.com/google/googletest), but they +don't require any dependencies and just write results to stderr, rather than +aborting the program. If all the tests pass, then `~~~ALL TESTS PASSED~~~` is +output, and the test harness that runs the binary during the make process knows +that everything ran correctly. If there's an error, the lack of the expected +string lets the harness know that the test failed. + +So, why are we running tests in this complicated way? So far, we've been building binaries that run locally on the Mac OS or Linux machine you're building on, but this approach becomes important when we're targeting simple micro controller devices. + +## Building for the "Blue Pill" STM32F103 + +The goal of this library is to enable machine learning on resource-constrained micro controllers and DSPs, and as part of that we've targeted the ["Blue Pill" STM32F103-compatible development board](https://github.com/google/googletest) as a cheap and popular platform. It only has 20KB of RAM and 64KB of flash, so it's a good device to ensure we can run efficiently on small chips.
+ +It's fairly easy to [buy and wire up a physical board](https://github.com/google/stm32_bare_lib#wiring-up-your-blue-pill), but even if you don't have an actual device, the [Renode project](https://renode.io/) makes it easy to run a faithful emulation on your desktop machine. You'll need [Docker](https://www.docker.com/) installed, but once you have that set up, try running the following command: + +`make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile TARGET=bluepill test` + +You should see a similar set of outputs as you did in the previous section, with the addition of some extra Docker logging messages. These are because we're using Docker to run the Renode micro controller emulation tool, and the tests themselves are being run on a simulated STM32F103 device. The communication channels between an embedded device and the host are quite limited, so the test harness looks at the output of the debug log to see if tests have passed, just as it did in the previous section. This makes it a very flexible way to run cross-platform tests, even when a platform has no operating system facilities, as long as it can output debugging text logs. + +To understand what's happening here, try running the same depthwise convolution test, but through the emulated device test harness, with the following command: + +``` +tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh \ +tensorflow/contrib/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test + +``` + +You should see output that looks something like this: + +``` +Sending build context to Docker daemon 21.5kB +Step 1/2 : FROM antmicro/renode:latest + ---> 1b670a243e8f +Step 2/2 : LABEL maintainer="Pete Warden " + ---> Using cache + ---> 3afcd410846d +Successfully built 3afcd410846d +Successfully tagged renode_bluepill:latest +LOGS: +... +03:27:32.4340 [INFO] machine-0: Machine started. 
+03:27:32.4790 [DEBUG] cpu.uartSemihosting: [+0.22s host +0s virt 0s virt from start] Testing SimpleTest +03:27:32.4812 [DEBUG] cpu.uartSemihosting: [+2.21ms host +0s virt 0s virt from start] Testing SimpleTestQuantized +03:27:32.4833 [DEBUG] cpu.uartSemihosting: [+2.14ms host +0s virt 0s virt from start] Testing SimpleTestRelu +03:27:32.4834 [DEBUG] cpu.uartSemihosting: [+0.18ms host +0s virt 0s virt from start] Testing SimpleTestReluQuantized +03:27:32.4838 [DEBUG] cpu.uartSemihosting: [+0.4ms host +0s virt 0s virt from start] 4/4 tests passed +03:27:32.4839 [DEBUG] cpu.uartSemihosting: [+41µs host +0s virt 0s virt from start] ~~~ALL TESTS PASSED~~~ +03:27:32.4839 [DEBUG] cpu.uartSemihosting: [+5µs host +0s virt 0s virt from start] +... +tensorflow/contrib/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test: PASS +``` + +There's a lot of output here, but you should be able to see that the same tests +that were covered when we ran locally on the development machine show up in the +debug logs here, along with the magic string `~~~ALL TESTS PASSED~~~`. This is +the exact same code as before, just compiled and run on the STM32F103 rather +than your desktop. We hope that the simplicity of this testing approach will +help make adding support for new platforms as easy as possible. diff --git a/tensorflow/contrib/lite/experimental/micro/compatibility.h b/tensorflow/contrib/lite/experimental/micro/compatibility.h new file mode 100644 index 0000000000000000000000000000000000000000..4f0fd9f3120a5db74cdfb84e7b17a0f3656520bc --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/compatibility.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_ + +// C++ will automatically create class-specific delete operators for virtual +// objects, which by default call the global delete function. For embedded +// applications we want to avoid this, and won't be calling new/delete on these +// objects, so we need to override the default implementation with one that does +// nothing to avoid linking in ::delete(). +// This macro needs to be included in all subclasses of a virtual base class in +// the private section. +#ifdef TF_LITE_STATIC_MEMORY +#define TF_LITE_REMOVE_VIRTUAL_DELETE \ + void operator delete(void* p) {} +#else +#define TF_LITE_REMOVE_VIRTUAL_DELETE +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..626f733540264c6fa13ab82557b822690b2d5b8f --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD @@ -0,0 +1,35 @@ +# Description: +# TensorFlow Lite microcontroller example. 
+ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl", + "tflite_micro_cc_test", +) + +tflite_micro_cc_test( + name = "micro_speech_test", + srcs = [ + "micro_speech_test.cc", + "no_features_data.cc", + "no_features_data.h", + "tiny_conv_model_data.cc", + "tiny_conv_model_data.h", + "yes_features_data.cc", + "yes_features_data.h", + ], + tags = [ + "nomsan", + ], + deps = [ + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + "//tensorflow/contrib/lite/experimental/micro/kernels:all_ops_resolver", + "//tensorflow/contrib/lite/experimental/micro/kernels:micro_ops", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + "//tensorflow/contrib/lite/schema:schema_fbs", + ], +) diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/README.md b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/README.md new file mode 100644 index 0000000000000000000000000000000000000000..438a432356be5c3cc9bfd08de5bd4d6f797c7014 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/README.md @@ -0,0 +1,103 @@ +# Micro Speech Example + +This example shows how you can use TensorFlow Lite to run a 20 kilobyte neural network model to recognize keywords in speech. It's designed to run on systems with very small amounts of memory such as microcontrollers and DSPs. The code itself also has a small footprint (for example around 22 kilobytes on a Cortex M3) and only uses about 10 kilobytes of RAM for working memory, so it's able to run on systems like an STM32F103 with only 20 kilobytes of total SRAM and 64 kilobytes of Flash.
+ +## Table of Contents + + * [Getting Started](#getting-started) + * [Getting Started on a Microcontroller](#getting-started-on-a-microcontroller) + * [Calculating the Input to the Neural Network](#calculating-the-input-to-the-neural-network) + * [Creating Your Own Model](#creating-your-own-model) + +## Getting Started + +To compile and test this example on a desktop Linux or MacOS machine, download [the TensorFlow source code](https://github.com/tensorflow/tensorflow), `cd` into the source directory from a terminal, and then retrieve the support libraries you need by running: + +``` +tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh +``` + +This will take a few minutes, and downloads frameworks the code uses like [CMSIS](https://developer.arm.com/embedded/cmsis) and [flatbuffers](https://google.github.io/flatbuffers/). Once that process has finished, run: + +``` +make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile test_micro_speech +``` + +You should see a series of files get compiled, followed by some logging output from a test, which should conclude with "~~~ALL TESTS PASSED~~~". If you see this, it means that a small program has been built and run that loads a trained TensorFlow model, runs some example inputs through it, and got the expected outputs. This particular test runs spectrograms generated from recordings of people saying "Yes" and "No", and checks that the network correctly identifies them. + +To understand how TensorFlow Lite does this, you can look at the `TestInvoke()` function in [micro_speech_test.cc](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc). It's a fairly small amount of code, creating an interpreter, getting a handle to a model that's been compiled into the program, and then invoking the interpreter with the model and sample inputs. 
+ +## Getting Started on a Microcontroller + +Once you have downloaded the dependencies and got the x86/Linux build working, you can try building a version for the STM32F103 'bluepill' device. The following command will build the test and then run it on an emulator, assuming you have Docker installed: + +``` +make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile TARGET=bluepill test_micro_speech +``` + +If you have a real device [(see here for how to set one up)](https://github.com/google/stm32_bare_lib/tree/master/README.md) you can then convert the ELF file into a `.bin` format executable to load onto it by running: + +``` +arm-none-eabi-objcopy \ +tensorflow/contrib/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/micro_speech_test \ +tensorflow/contrib/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/micro_speech_test.bin \ +--output binary +``` + +## Calculating the Input to the Neural Network + +The TensorFlow Lite model doesn't take in raw audio sample data. Instead it works with spectrograms, which are two dimensional arrays that are made up of slices of frequency information, each taken from a different time window. This test uses spectrograms that have been pre-calculated from one-second WAV files in the test data set. In a complete application these spectrograms would be calculated at runtime from microphone inputs, but the code for doing that is not yet included in this sample code. + +The recipe for creating the spectrogram data is that each frequency slice is created by running an FFT across a 30ms section of the audio sample data. The input samples are treated as being between -1 and +1 as real values (encoded as -32,768 and 32,767 in 16-bit signed integer samples). This results in an FFT with 256 entries. Every sequence of six entries is averaged together, giving a total of 43 frequency buckets in the final slice.
The results are stored as unsigned eight-bit values, where 0 represents a real number of zero, and 255 represents 127.5 as a real number. Each adjacent frequency entry is stored in ascending memory order (frequency bucket 0 at data[0], bucket 1 at data [1], etc). The window for the frequency analysis is then moved forward by 20ms, and the process repeated, storing the results in the next memory row (for example bucket 0 in this moved window would be in data[43 + 0], etc). This process happens 49 times in total, producing a single channel image that is 43 pixels wide, and 49 rows high. Here's an illustration of the process: + +![spectrogram diagram](https://storage.googleapis.com/download.tensorflow.org/example_images/spectrogram_diagram.png) + + +The test data files have been generated by running the following commands: + +``` +bazel run tensorflow/examples/speech_commands:wav_to_features -- \ +--input_wav=${HOME}/speech_commands_test_set_v0.02/yes/f2e59fea_nohash_1.wav \ +--output_c_file=yes_features_data.cc \ +--window_stride=20 --preprocess=average --quantize=1 + +bazel run tensorflow/examples/speech_commands:wav_to_features -- \ +--input_wav=${HOME}/speech_commands_test_set_v0.02/no/f9643d42_nohash_4.wav \ +--output_c_file=no_features_data.cc \ +--window_stride=20 --preprocess=average --quantize=1 +``` + +## Creating Your Own Model + +The neural network model used in this example was built using the [TensorFlow speech commands tutorial](https://www.tensorflow.org/tutorials/sequences/audio_recognition). 
If you would like to create your own, you can start by training a model with this command: + +``` +bazel run -c opt --copt=-mavx2 --copt=-mfma \ +tensorflow/examples/speech_commands:train -- \ +--model_architecture=tiny_conv --window_stride=20 --preprocess=average \ +--wanted_words="yes,no" --silence_percentage=25 --unknown_percentage=25 --quantize=1 +``` + +If you see a compilation error on older machines, try leaving out the `--copt` arguments; they are just there to accelerate training on chips that support the extensions. The training process is likely to take a couple of hours. Once it has completed, the next step is to freeze the variables: + +``` +bazel run tensorflow/examples/speech_commands:freeze -- \ +--model_architecture=tiny_conv --window_stride=20 --preprocess=average \ +--wanted_words="yes,no" --quantize=1 --output_file=/tmp/tiny_conv.pb +``` + +The next step is to create a TensorFlow Lite file from the frozen graph: + +``` +bazel run tensorflow/contrib/lite/toco:toco -- \ +--input_file=/tmp/tiny_conv.pb --output_file=/tmp/tiny_conv.tflite \ +--input_shapes=1,49,43,1 --input_arrays=Reshape_1 --output_arrays='labels_softmax' \ +--inference_type=QUANTIZED_UINT8 --mean_values=0 --std_values=2 \ +--change_concat_input_ranges=false +``` + +Finally, convert the file into a C source file that can be compiled into an embedded system: + +``` +xxd -i /tmp/tiny_conv.tflite > /tmp/tiny_conv_model_data.cc +``` diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0f4731fd4b2a0890bb29d818145f34affde8f304 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc @@ -0,0 +1,137 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.h" +#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h" +#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" +#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h" +#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h" +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/version.h" + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestInvoke) { + // Set up logging. + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = &micro_error_reporter; + + // Map the model into a usable data structure. This doesn't involve any + // copying or parsing, it's a very lightweight operation. 
+ const tflite::Model* model = ::tflite::GetModel(g_tiny_conv_model_data); + if (model->version() != TFLITE_SCHEMA_VERSION) { + error_reporter->Report( + "Model provided is schema version %d not equal " + "to supported version %d.\n", + model->version(), TFLITE_SCHEMA_VERSION); + } + + // This pulls in all the operation implementations we need. + tflite::ops::micro::AllOpsResolver resolver; + + // Create an area of memory to use for input, output, and intermediate arrays. + const int tensor_arena_size = 10 * 1024; + uint8_t tensor_arena[tensor_arena_size]; + tflite::SimpleTensorAllocator tensor_allocator(tensor_arena, + tensor_arena_size); + + // Build an interpreter to run the model with. + tflite::MicroInterpreter interpreter(model, resolver, &tensor_allocator, + error_reporter); + + // Get information about the memory area to use for the model's input. + TfLiteTensor* input = interpreter.input(0); + + // Make sure the input has the properties we expect. + TF_LITE_MICRO_EXPECT_NE(nullptr, input); + TF_LITE_MICRO_EXPECT_EQ(4, input->dims->size); + TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(49, input->dims->data[1]); + TF_LITE_MICRO_EXPECT_EQ(43, input->dims->data[2]); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, input->type); + + // Copy a spectrogram created from a .wav audio file of someone saying "Yes", + // into the memory area used for the input. + const uint8_t* yes_features_data = g_yes_f2e59fea_nohash_1_data; + for (int i = 0; i < input->bytes; ++i) { + input->data.uint8[i] = yes_features_data[i]; + } + + // Run the model on this input and make sure it succeeds. + TfLiteStatus invoke_status = interpreter.Invoke(); + if (invoke_status != kTfLiteOk) { + error_reporter->Report("Invoke failed\n"); + } + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status); + + // Get the output from the model, and make sure it's the expected size and + // type. 
+ TfLiteTensor* output = interpreter.output(0); + TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size); + TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(4, output->dims->data[1]); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, output->type); + + // There are four possible classes in the output, each with a score. + const int kSilenceIndex = 0; + const int kUnknownIndex = 1; + const int kYesIndex = 2; + const int kNoIndex = 3; + + // Make sure that the expected "Yes" score is higher than the other classes. + uint8_t silence_score = output->data.uint8[kSilenceIndex]; + uint8_t unknown_score = output->data.uint8[kUnknownIndex]; + uint8_t yes_score = output->data.uint8[kYesIndex]; + uint8_t no_score = output->data.uint8[kNoIndex]; + TF_LITE_MICRO_EXPECT_GT(yes_score, silence_score); + TF_LITE_MICRO_EXPECT_GT(yes_score, unknown_score); + TF_LITE_MICRO_EXPECT_GT(yes_score, no_score); + + // Now test with a different input, from a recording of "No". + const uint8_t* no_features_data = g_no_f9643d42_nohash_4_data; + for (int i = 0; i < input->bytes; ++i) { + input->data.uint8[i] = no_features_data[i]; + } + + // Run the model on this "No" input. + invoke_status = interpreter.Invoke(); + if (invoke_status != kTfLiteOk) { + error_reporter->Report("Invoke failed\n"); + } + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status); + + // Get the output from the model, and make sure it's the expected size and + // type. + output = interpreter.output(0); + TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size); + TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(4, output->dims->data[1]); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, output->type); + + // Make sure that the expected "No" score is higher than the other classes. 
+ silence_score = output->data.uint8[kSilenceIndex]; + unknown_score = output->data.uint8[kUnknownIndex]; + yes_score = output->data.uint8[kYesIndex]; + no_score = output->data.uint8[kNoIndex]; + TF_LITE_MICRO_EXPECT_GT(no_score, silence_score); + TF_LITE_MICRO_EXPECT_GT(no_score, unknown_score); + TF_LITE_MICRO_EXPECT_GT(no_score, yes_score); + + error_reporter->Report("Ran successfully\n"); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.cc b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.cc new file mode 100644 index 0000000000000000000000000000000000000000..3615deb26c4f0ea0b3018a6144e7f2cc58cd8a1e --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.cc @@ -0,0 +1,152 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.h" + +/* File automatically created by + * tensorflow/examples/speech_commands/wav_to_features.py \ + * --sample_rate=16000 \ + * --clip_duration_ms=1000 \ + * --window_size_ms=30 \ + * --window_stride_ms=20 \ + * --feature_bin_count=40 \ + * --quantize \ + * --preprocess="average" \ + * --input_wav="speech_commands_test_set_v0.02/no/f9643d42_nohash_4.wav" \ + * --output_c_file="no_features_data.cc" \ + */ + +const int g_no_f9643d42_nohash_4_width = 43; +const int g_no_f9643d42_nohash_4_height = 49; +const unsigned char g_no_f9643d42_nohash_4_data[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 67, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 2, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 195, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 230, 2, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 7, + 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 255, 7, 16, 1, 1, 0, 2, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 7, 22, 0, 1, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 238, 5, 20, 3, 4, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 144, 4, 19, 3, 5, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 42, 6, 3, + 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 3, 1, 5, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 1, 3, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, +}; diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.h b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.h new file mode 100644 index 0000000000000000000000000000000000000000..b53d0a202b75eab7db82107f2c71c504a85f881e --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.h @@ -0,0 +1,23 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_FEATURES_DATA_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_FEATURES_DATA_H_ + +extern const int g_no_f9643d42_nohash_4_width; +extern const int g_no_f9643d42_nohash_4_height; +extern const unsigned char g_no_f9643d42_nohash_4_data[]; + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_FEATURES_DATA_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc new file mode 100644 index 0000000000000000000000000000000000000000..f0769a1237d64a5f727ec86c5d8ff2e20086436d --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc @@ -0,0 +1,1673 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Automatically created from a TensorFlow Lite flatbuffer using the command: +// xxd -i tiny_conv.tflite > tiny_conv_model_data.cc +// See the README for a full description of the creation process. + +#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h" + +const unsigned char g_tiny_conv_model_data[] = { + 0x18, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x0e, 0x00, + 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x08, 0x4d, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0xf4, 0x47, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, + 0x54, 0x4f, 0x43, 0x4f, 0x20, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x74, + 0x65, 0x64, 0x2e, 0x00, 0x09, 0x00, 0x00, 0x00, 0xd4, 0x47, 0x00, 0x00, + 0xb4, 0x47, 0x00, 0x00, 0xe4, 0x02, 0x00, 0x00, 0xb4, 0x02, 0x00, 0x00, + 0xac, 0x02, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb8, 0xb3, 0xff, 0xff, + 0xbc, 0xb3, 0xff, 0xff, 0xc0, 0xb3, 0xff, 0xff, 0x1e, 0xb4, 0xff, 0xff, + 0x04, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x00, 0x89, 0xa5, 0xe8, 0xc1, + 0xb1, 0x89, 0x5b, 0xc6, 0x4f, 0x9b, 0xd3, 0x74, 0x93, 0x88, 0xff, 0xaf, + 0x89, 0xff, 0xf4, 0x70, 0xcc, 0x75, 0x78, 0xbf, 0x92, 0xcd, 0xa9, 0xa8, + 0xd6, 0x6a, 0x6f, 0x7b, 0x7f, 0xd8, 0xa8, 0xb1, 0xe6, 0x32, 0x21, 0x70, + 0xa0, 0x9c, 0x6f, 0xc8, 0xc6, 0x59, 0x67, 0x93, 0x97, 0xca, 0x3f, 0xde, + 0xcb, 0x74, 0x7c, 0xb5, 0xa4, 0xd9, 0x66, 0xc6, 0x87, 0x98, 0xa5, 0xd0, + 0xbb, 0xb9, 0xc2, 0xb2, 0xaa, 0x79, 0x25, 0xb9, 0x6d, 0x5a, 0xc8, 0x7f, + 0x70, 0x85, 0x79, 0xbc, 0x6a, 0x9b, 0xd1, 0x9a, 0x9c, 0x51, 0x53, 0x71, + 0x89, 0xc0, 0xb4, 0xac, 0xae, 0x47, 0x67, 0x70, 0x79, 0xd2, 0x81, 0xa5, + 0xd2, 0x09, 0x38, 0x82, 0x74, 0xc9, 0x5d, 0xaf, 0xc1, 0x4f, 0x53, 0x99, + 0xcb, 
0xb7, 0x3a, 0xba, 0xe8, 0x7f, 0x76, 0xb9, 0xb3, 0xd3, 0x60, 0xc0, + 0x93, 0x9f, 0x87, 0xbd, 0xd0, 0xb8, 0xca, 0xc1, 0xb6, 0x6c, 0x01, 0xc1, + 0x5c, 0x5d, 0xb2, 0x82, 0x76, 0x77, 0x39, 0xbc, 0x72, 0x6a, 0xc3, 0xb4, + 0x79, 0x21, 0x48, 0x42, 0x86, 0xa6, 0xbd, 0xaf, 0xae, 0x23, 0x9c, 0x69, + 0x78, 0xc3, 0x6b, 0xb3, 0xab, 0x43, 0xb2, 0x88, 0x71, 0xc6, 0x6b, 0xbe, + 0xc3, 0x75, 0xc2, 0xc3, 0xa5, 0xcf, 0x32, 0xbe, 0xcb, 0xb0, 0xb8, 0xc1, + 0x9c, 0xcf, 0x64, 0xc4, 0xb4, 0x96, 0xa8, 0xb9, 0xcb, 0xc0, 0xc0, 0xb8, + 0xb8, 0x77, 0x65, 0xc0, 0xc4, 0xb3, 0xc5, 0x77, 0x9b, 0x61, 0xd4, 0xac, + 0x7e, 0x36, 0xb1, 0xae, 0x36, 0x36, 0xb8, 0x39, 0x6b, 0x70, 0x9c, 0xb5, + 0x88, 0x5c, 0xb3, 0x6a, 0xad, 0xc5, 0x7b, 0xb4, 0xad, 0xaa, 0xc4, 0x84, + 0x5e, 0xc4, 0x67, 0xc1, 0xde, 0xba, 0xcf, 0xbd, 0xa0, 0xd3, 0x35, 0xb3, + 0xe7, 0xc8, 0xb8, 0xb8, 0xaf, 0xb4, 0x59, 0xb8, 0xb4, 0xac, 0xac, 0xaa, + 0xc7, 0xad, 0xc8, 0xb6, 0xac, 0x99, 0xa0, 0xcb, 0xc1, 0xc8, 0xcb, 0x89, + 0xc3, 0xac, 0xca, 0x8b, 0x97, 0x1f, 0xbd, 0xbf, 0x13, 0xad, 0xc8, 0x41, + 0x56, 0x3c, 0x86, 0xb2, 0x61, 0xc4, 0xbb, 0x71, 0xba, 0x92, 0x8d, 0xc3, + 0x86, 0xcb, 0xc5, 0x8d, 0x88, 0xc8, 0x6a, 0xbf, 0x9c, 0xcd, 0xcd, 0xc0, + 0x81, 0xb1, 0x47, 0xb5, 0xf0, 0xce, 0xb1, 0xc1, 0xaa, 0xa8, 0x54, 0xcb, + 0xbc, 0xc7, 0xc5, 0x8e, 0xc3, 0xce, 0xc7, 0xb9, 0xb9, 0xa1, 0xc5, 0xbd, + 0xb8, 0xb8, 0xb7, 0x81, 0xb6, 0xba, 0xd2, 0x90, 0xbc, 0x96, 0xbe, 0xba, + 0x53, 0xb5, 0xc7, 0x3c, 0x3c, 0x1f, 0x90, 0xaa, 0x5a, 0xb8, 0xba, 0x7e, + 0xbc, 0x9e, 0xc2, 0xb1, 0x6e, 0xc0, 0xc4, 0x91, 0xf0, 0xb5, 0x60, 0xad, + 0x73, 0xba, 0xcd, 0xba, 0x6e, 0x94, 0x39, 0xb5, 0xe4, 0xbe, 0xb4, 0xb5, + 0xa0, 0xa9, 0x51, 0xac, 0xbc, 0xc2, 0xb3, 0x8a, 0xbd, 0x9a, 0xca, 0xb3, + 0xbf, 0xaf, 0xb5, 0x9a, 0xb9, 0xc3, 0xb6, 0x92, 0xb5, 0xc1, 0xb0, 0x95, + 0xd6, 0xcc, 0xbb, 0xbb, 0xa9, 0xb9, 0xac, 0x4a, 0x62, 0x27, 0xa7, 0xa7, + 0x30, 0xbd, 0xb1, 0x73, 0xa1, 0x74, 0xc2, 0xb7, 0x58, 0xc0, 0xae, 0x8f, + 0xe1, 0xac, 0x4e, 0xb0, 0x55, 0xc9, 0xc8, 0x9f, 0x83, 0x8e, 0x3e, 0xd5, + 0xb5, 
0xbe, 0xcd, 0xb2, 0xa6, 0xc8, 0x64, 0xac, 0xc0, 0xc8, 0xaf, 0x99, + 0xc5, 0x9e, 0xb8, 0xbd, 0xa9, 0xc2, 0xb3, 0x81, 0xb4, 0xc2, 0xb4, 0x8f, + 0xbc, 0xb8, 0x9c, 0x88, 0xbe, 0xc6, 0xbf, 0xba, 0xc8, 0xb4, 0xab, 0x5b, + 0x92, 0x51, 0xb1, 0x9a, 0x44, 0xb9, 0xab, 0x80, 0xa5, 0x3e, 0xc0, 0xa5, + 0x5c, 0xb6, 0xa8, 0xa2, 0xb3, 0x9a, 0x6b, 0xb3, 0x34, 0xc6, 0x7e, 0x96, + 0xcb, 0x88, 0x48, 0xc6, 0xa3, 0xbb, 0xd2, 0xa2, 0xaf, 0xd0, 0x6e, 0xae, + 0xb4, 0xce, 0xc8, 0x8f, 0xd7, 0xad, 0xc8, 0xb0, 0xae, 0xb7, 0xb2, 0x70, + 0xb9, 0xad, 0xc1, 0xa0, 0xcb, 0xa2, 0xb0, 0x9b, 0xbe, 0xd3, 0xca, 0xb6, + 0xbd, 0xaf, 0xa9, 0x82, 0xa1, 0xd7, 0xbc, 0x9b, 0x8b, 0xac, 0xaa, 0xac, + 0xad, 0x37, 0xb7, 0xb6, 0x46, 0xae, 0xa9, 0xbd, 0x6b, 0x90, 0x5e, 0xcd, + 0x23, 0xa4, 0x76, 0xa1, 0xc4, 0x96, 0x50, 0xcc, 0x95, 0x99, 0x93, 0xa7, + 0xb2, 0xe1, 0x7c, 0xbd, 0xbd, 0xb5, 0xbf, 0x9a, 0xca, 0x80, 0xd7, 0xae, + 0x79, 0xa8, 0xaa, 0xb2, 0xbc, 0x51, 0xda, 0xa3, 0x80, 0x8b, 0xa2, 0xc8, + 0xd1, 0x94, 0xe1, 0xc4, 0xbd, 0xae, 0xae, 0xcc, 0xb3, 0xca, 0xd5, 0xa1, + 0xd5, 0xa7, 0xaf, 0xd2, 0xb4, 0x8d, 0xcc, 0xc8, 0x63, 0xa3, 0xa4, 0xdf, + 0x6f, 0x7e, 0x98, 0xdf, 0x1b, 0x7b, 0x43, 0x99, 0xb0, 0x99, 0x71, 0xdb, + 0x63, 0x7b, 0x69, 0x9c, 0xba, 0xcd, 0x90, 0xd0, 0xb6, 0xa6, 0x9e, 0x95, + 0x50, 0xb6, 0xff, 0xff, 0xae, 0xb6, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0xc7, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x77, 0x00, 0x00, 0x00, + 0xda, 0xb6, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0xc0, 0x44, 0x00, 0x00, + 0x2c, 0x30, 0x38, 0x5a, 0x3d, 0x4c, 0x44, 0x3b, 0x48, 0x48, 0x44, 0x57, + 0x3f, 0x43, 0x45, 0x3a, 0x24, 0x32, 0x21, 0x5c, 0x3f, 0x3a, 0x38, 0x3a, + 0x35, 0x35, 0x2f, 0x51, 0x3c, 0x3a, 0x45, 0x3a, 0x3b, 0x41, 0x39, 0x55, + 0x3c, 0x41, 0x39, 0x44, 0x3a, 0x40, 0x37, 0x48, 0x33, 0x47, 0x36, 0x3e, + 0x3c, 0x41, 0x3f, 0x3e, 0x3e, 0x47, 0x36, 0x3e, 0x41, 0x33, 0x3e, 0x3b, + 0x3a, 
0x46, 0x45, 0x40, 0x48, 0x3a, 0x35, 0x4b, 0x45, 0x4d, 0x3c, 0x49, + 0x42, 0x44, 0x3c, 0x4c, 0x3e, 0x3c, 0x44, 0x32, 0x33, 0x41, 0x36, 0x4b, + 0x38, 0x3b, 0x3c, 0x38, 0x3b, 0x45, 0x34, 0x46, 0x40, 0x4e, 0x44, 0x35, + 0x43, 0x36, 0x3d, 0x40, 0x3e, 0x48, 0x40, 0x34, 0x3a, 0x46, 0x45, 0x43, + 0x45, 0x3f, 0x47, 0x37, 0x36, 0x35, 0x44, 0x3a, 0x3e, 0x37, 0x39, 0x40, + 0x3a, 0x3f, 0x3f, 0x4c, 0x3e, 0x41, 0x43, 0x35, 0x3f, 0x3d, 0x3d, 0x4c, + 0x3c, 0x4a, 0x46, 0x3c, 0x3a, 0x41, 0x40, 0x4e, 0x36, 0x47, 0x40, 0x3b, + 0x47, 0x42, 0x38, 0x4d, 0x48, 0x47, 0x3c, 0x3c, 0x33, 0x3b, 0x3e, 0x42, + 0x3f, 0x3e, 0x3a, 0x3d, 0x32, 0x39, 0x41, 0x46, 0x3a, 0x3a, 0x3e, 0x3e, + 0x47, 0x48, 0x4e, 0x36, 0x44, 0x40, 0x41, 0x45, 0x3a, 0x3c, 0x38, 0x55, + 0x2e, 0x26, 0x2f, 0x32, 0x3f, 0x41, 0x3e, 0x4c, 0x45, 0x36, 0x40, 0x31, + 0x17, 0x2e, 0x14, 0x53, 0x34, 0x30, 0x34, 0x3f, 0x2e, 0x44, 0x2b, 0x4e, + 0x34, 0x3e, 0x34, 0x43, 0x3d, 0x35, 0x3f, 0x46, 0x39, 0x40, 0x38, 0x3e, + 0x35, 0x3b, 0x35, 0x45, 0x3d, 0x40, 0x38, 0x37, 0x40, 0x3e, 0x32, 0x3e, + 0x41, 0x39, 0x30, 0x41, 0x3a, 0x32, 0x3e, 0x3d, 0x39, 0x31, 0x33, 0x3e, + 0x41, 0x47, 0x40, 0x47, 0x35, 0x33, 0x3c, 0x32, 0x40, 0x3c, 0x42, 0x49, + 0x34, 0x38, 0x39, 0x37, 0x39, 0x35, 0x40, 0x4d, 0x37, 0x43, 0x42, 0x3e, + 0x3f, 0x3c, 0x3e, 0x51, 0x36, 0x37, 0x42, 0x41, 0x36, 0x31, 0x43, 0x3d, + 0x46, 0x43, 0x37, 0x46, 0x32, 0x45, 0x42, 0x36, 0x3f, 0x42, 0x42, 0x41, + 0x3d, 0x46, 0x39, 0x41, 0x3c, 0x3f, 0x38, 0x3c, 0x43, 0x43, 0x3d, 0x3c, + 0x3d, 0x41, 0x38, 0x42, 0x3a, 0x3d, 0x43, 0x42, 0x41, 0x40, 0x39, 0x36, + 0x3a, 0x3c, 0x3c, 0x4f, 0x44, 0x36, 0x39, 0x35, 0x46, 0x46, 0x36, 0x4a, + 0x3a, 0x42, 0x43, 0x39, 0x3f, 0x3d, 0x3c, 0x47, 0x38, 0x3f, 0x43, 0x40, + 0x36, 0x3c, 0x45, 0x3b, 0x33, 0x36, 0x3b, 0x39, 0x3c, 0x35, 0x40, 0x38, + 0x40, 0x3e, 0x3f, 0x48, 0x3f, 0x34, 0x40, 0x53, 0x26, 0x2c, 0x29, 0x39, + 0x2a, 0x38, 0x3f, 0x45, 0x32, 0x31, 0x4a, 0x37, 0x1c, 0x28, 0x09, 0x43, + 0x35, 0x3b, 0x33, 0x3c, 0x32, 0x3f, 0x28, 0x41, 0x36, 0x35, 0x3a, 0x37, + 0x41, 
0x39, 0x32, 0x3c, 0x40, 0x3c, 0x3c, 0x32, 0x38, 0x39, 0x37, 0x44, + 0x3a, 0x33, 0x41, 0x36, 0x37, 0x3c, 0x35, 0x3a, 0x3d, 0x30, 0x3d, 0x41, + 0x37, 0x3c, 0x45, 0x3a, 0x37, 0x2f, 0x36, 0x3c, 0x3a, 0x3d, 0x39, 0x48, + 0x46, 0x33, 0x3a, 0x3e, 0x40, 0x3d, 0x3b, 0x52, 0x38, 0x45, 0x34, 0x47, + 0x39, 0x36, 0x37, 0x56, 0x42, 0x3f, 0x33, 0x36, 0x38, 0x3f, 0x40, 0x53, + 0x3e, 0x37, 0x3d, 0x3c, 0x48, 0x3a, 0x3d, 0x33, 0x39, 0x40, 0x3e, 0x35, + 0x3d, 0x46, 0x38, 0x36, 0x37, 0x43, 0x3a, 0x3c, 0x40, 0x38, 0x39, 0x3b, + 0x39, 0x3a, 0x42, 0x3d, 0x34, 0x3f, 0x35, 0x43, 0x3a, 0x35, 0x46, 0x3a, + 0x48, 0x38, 0x3b, 0x48, 0x3c, 0x35, 0x42, 0x3d, 0x3a, 0x3d, 0x38, 0x42, + 0x3e, 0x3c, 0x33, 0x39, 0x34, 0x30, 0x42, 0x44, 0x41, 0x3d, 0x3c, 0x39, + 0x3c, 0x3a, 0x39, 0x41, 0x3d, 0x44, 0x3c, 0x40, 0x3f, 0x3e, 0x42, 0x3f, + 0x37, 0x40, 0x39, 0x3b, 0x42, 0x43, 0x49, 0x37, 0x39, 0x46, 0x35, 0x3c, + 0x3e, 0x39, 0x45, 0x52, 0x24, 0x2d, 0x38, 0x35, 0x3a, 0x3a, 0x3c, 0x44, + 0x39, 0x32, 0x51, 0x3f, 0x16, 0x34, 0x0a, 0x49, 0x39, 0x38, 0x39, 0x3e, + 0x2f, 0x36, 0x24, 0x3f, 0x37, 0x34, 0x38, 0x3b, 0x34, 0x34, 0x30, 0x3b, + 0x3d, 0x36, 0x35, 0x42, 0x33, 0x40, 0x37, 0x35, 0x43, 0x3f, 0x3f, 0x39, + 0x3a, 0x43, 0x36, 0x3e, 0x39, 0x3d, 0x3f, 0x3d, 0x47, 0x3b, 0x39, 0x37, + 0x35, 0x42, 0x3f, 0x3b, 0x41, 0x3a, 0x42, 0x4b, 0x3d, 0x3f, 0x3d, 0x3e, + 0x38, 0x3b, 0x34, 0x4e, 0x3f, 0x39, 0x36, 0x43, 0x39, 0x35, 0x41, 0x4d, + 0x3c, 0x39, 0x43, 0x33, 0x37, 0x3b, 0x41, 0x48, 0x3c, 0x3f, 0x39, 0x32, + 0x35, 0x3d, 0x42, 0x35, 0x3d, 0x3e, 0x37, 0x3b, 0x38, 0x3a, 0x44, 0x36, + 0x42, 0x35, 0x48, 0x40, 0x3a, 0x44, 0x44, 0x39, 0x43, 0x41, 0x3c, 0x37, + 0x47, 0x3b, 0x42, 0x42, 0x45, 0x3a, 0x40, 0x46, 0x35, 0x3f, 0x3a, 0x48, + 0x35, 0x44, 0x3f, 0x37, 0x33, 0x3e, 0x45, 0x49, 0x39, 0x43, 0x47, 0x37, + 0x3f, 0x3f, 0x3b, 0x44, 0x38, 0x3d, 0x39, 0x42, 0x37, 0x3e, 0x40, 0x45, + 0x3b, 0x3f, 0x40, 0x34, 0x42, 0x3f, 0x43, 0x3c, 0x43, 0x41, 0x38, 0x38, + 0x38, 0x41, 0x55, 0x33, 0x33, 0x39, 0x39, 0x3c, 0x35, 0x39, 0x38, 0x42, + 0x27, 
0x26, 0x32, 0x41, 0x41, 0x32, 0x3f, 0x47, 0x3a, 0x38, 0x48, 0x37, + 0x11, 0x27, 0x08, 0x49, 0x35, 0x42, 0x3c, 0x2e, 0x34, 0x43, 0x25, 0x3b, + 0x3a, 0x33, 0x37, 0x30, 0x3c, 0x36, 0x2d, 0x3c, 0x3b, 0x39, 0x3b, 0x40, + 0x46, 0x3a, 0x30, 0x42, 0x35, 0x32, 0x36, 0x3a, 0x3a, 0x34, 0x34, 0x33, + 0x3d, 0x30, 0x3b, 0x42, 0x41, 0x3f, 0x3d, 0x3b, 0x44, 0x3d, 0x41, 0x41, + 0x3d, 0x3f, 0x40, 0x51, 0x42, 0x42, 0x36, 0x45, 0x30, 0x40, 0x32, 0x4f, + 0x3a, 0x3c, 0x40, 0x39, 0x3d, 0x3b, 0x3e, 0x4b, 0x3d, 0x37, 0x42, 0x46, + 0x40, 0x40, 0x47, 0x3d, 0x35, 0x3c, 0x3f, 0x46, 0x37, 0x37, 0x3a, 0x2e, + 0x3d, 0x3c, 0x3a, 0x46, 0x3a, 0x44, 0x3c, 0x3a, 0x32, 0x44, 0x31, 0x41, + 0x43, 0x36, 0x49, 0x39, 0x3d, 0x37, 0x3f, 0x41, 0x3b, 0x3b, 0x3c, 0x42, + 0x3c, 0x34, 0x3f, 0x3b, 0x40, 0x3e, 0x48, 0x47, 0x3e, 0x3c, 0x38, 0x39, + 0x3f, 0x35, 0x39, 0x3f, 0x3e, 0x3e, 0x3b, 0x43, 0x41, 0x40, 0x43, 0x41, + 0x3f, 0x37, 0x39, 0x41, 0x46, 0x32, 0x3d, 0x41, 0x36, 0x3f, 0x3e, 0x3f, + 0x36, 0x48, 0x43, 0x3d, 0x43, 0x3f, 0x34, 0x3d, 0x34, 0x35, 0x4f, 0x32, + 0x3c, 0x3f, 0x3d, 0x3f, 0x39, 0x3c, 0x3d, 0x47, 0x23, 0x36, 0x33, 0x45, + 0x37, 0x2e, 0x42, 0x42, 0x39, 0x34, 0x4f, 0x3f, 0x19, 0x2b, 0x01, 0x50, + 0x35, 0x3f, 0x37, 0x3c, 0x33, 0x35, 0x25, 0x32, 0x38, 0x3e, 0x40, 0x40, + 0x2f, 0x38, 0x35, 0x3d, 0x31, 0x42, 0x44, 0x3c, 0x3a, 0x3d, 0x2d, 0x3e, + 0x3b, 0x3e, 0x3d, 0x31, 0x3b, 0x37, 0x35, 0x31, 0x36, 0x35, 0x34, 0x31, + 0x41, 0x3a, 0x33, 0x32, 0x3c, 0x31, 0x3e, 0x3d, 0x40, 0x3b, 0x34, 0x45, + 0x36, 0x39, 0x3e, 0x3f, 0x3c, 0x45, 0x37, 0x4b, 0x42, 0x3d, 0x33, 0x43, + 0x3e, 0x40, 0x35, 0x4e, 0x38, 0x36, 0x3a, 0x33, 0x38, 0x44, 0x3f, 0x3c, + 0x3f, 0x40, 0x3a, 0x3c, 0x3c, 0x3c, 0x44, 0x29, 0x3a, 0x40, 0x35, 0x3a, + 0x3d, 0x48, 0x3b, 0x30, 0x45, 0x41, 0x45, 0x40, 0x37, 0x32, 0x3a, 0x35, + 0x3f, 0x38, 0x3b, 0x43, 0x3b, 0x3f, 0x33, 0x40, 0x3b, 0x40, 0x38, 0x33, + 0x39, 0x3c, 0x3c, 0x3f, 0x43, 0x33, 0x43, 0x40, 0x43, 0x3d, 0x33, 0x42, + 0x40, 0x32, 0x3e, 0x36, 0x40, 0x38, 0x43, 0x40, 0x44, 0x38, 0x34, 0x3c, + 0x3e, 
0x39, 0x47, 0x43, 0x40, 0x3b, 0x3f, 0x3f, 0x3c, 0x3b, 0x4b, 0x33, + 0x36, 0x49, 0x32, 0x41, 0x48, 0x45, 0x57, 0x3a, 0x40, 0x42, 0x40, 0x46, + 0x36, 0x35, 0x3c, 0x46, 0x22, 0x2e, 0x33, 0x3e, 0x3c, 0x39, 0x44, 0x4d, + 0x3f, 0x41, 0x51, 0x44, 0x15, 0x2e, 0x02, 0x4e, 0x39, 0x3a, 0x3c, 0x35, + 0x30, 0x38, 0x1e, 0x31, 0x40, 0x3b, 0x39, 0x3d, 0x3a, 0x37, 0x35, 0x36, + 0x46, 0x36, 0x3c, 0x3e, 0x39, 0x3e, 0x32, 0x40, 0x3b, 0x35, 0x42, 0x41, + 0x41, 0x38, 0x41, 0x35, 0x42, 0x36, 0x3c, 0x42, 0x3d, 0x41, 0x35, 0x31, + 0x3f, 0x44, 0x3e, 0x41, 0x3f, 0x35, 0x42, 0x4b, 0x3e, 0x36, 0x37, 0x34, + 0x36, 0x3d, 0x40, 0x49, 0x41, 0x3e, 0x3d, 0x3b, 0x38, 0x37, 0x40, 0x47, + 0x35, 0x32, 0x43, 0x38, 0x36, 0x3b, 0x33, 0x47, 0x33, 0x34, 0x3d, 0x47, + 0x3c, 0x37, 0x3d, 0x2b, 0x3a, 0x36, 0x3b, 0x3d, 0x43, 0x38, 0x35, 0x32, + 0x32, 0x37, 0x43, 0x36, 0x3f, 0x48, 0x38, 0x30, 0x3a, 0x3c, 0x42, 0x34, + 0x37, 0x3c, 0x37, 0x40, 0x48, 0x3e, 0x35, 0x3b, 0x3f, 0x38, 0x39, 0x3e, + 0x37, 0x35, 0x36, 0x3d, 0x3b, 0x3c, 0x40, 0x3d, 0x34, 0x40, 0x46, 0x42, + 0x3f, 0x3c, 0x3c, 0x3e, 0x40, 0x40, 0x3d, 0x3f, 0x3f, 0x44, 0x46, 0x41, + 0x32, 0x43, 0x40, 0x41, 0x3c, 0x42, 0x39, 0x38, 0x48, 0x44, 0x3d, 0x38, + 0x34, 0x40, 0x4e, 0x31, 0x3c, 0x42, 0x39, 0x48, 0x3c, 0x33, 0x3e, 0x40, + 0x20, 0x27, 0x39, 0x45, 0x45, 0x36, 0x47, 0x4c, 0x35, 0x3e, 0x4a, 0x36, + 0x16, 0x2f, 0x04, 0x4f, 0x3a, 0x35, 0x36, 0x3a, 0x2d, 0x36, 0x21, 0x34, + 0x3b, 0x32, 0x3d, 0x3c, 0x3c, 0x3f, 0x3b, 0x3b, 0x41, 0x46, 0x40, 0x3d, + 0x3b, 0x44, 0x33, 0x42, 0x34, 0x33, 0x3e, 0x45, 0x3f, 0x46, 0x39, 0x33, + 0x3b, 0x37, 0x37, 0x37, 0x42, 0x47, 0x3c, 0x35, 0x31, 0x41, 0x44, 0x3a, + 0x3b, 0x33, 0x39, 0x44, 0x42, 0x33, 0x3d, 0x3f, 0x43, 0x33, 0x41, 0x4a, + 0x35, 0x46, 0x36, 0x3e, 0x39, 0x41, 0x41, 0x4c, 0x34, 0x3d, 0x38, 0x33, + 0x3c, 0x3f, 0x43, 0x44, 0x37, 0x35, 0x35, 0x3c, 0x43, 0x34, 0x3e, 0x2d, + 0x3f, 0x35, 0x38, 0x3c, 0x33, 0x35, 0x43, 0x2a, 0x40, 0x33, 0x34, 0x40, + 0x3d, 0x38, 0x36, 0x2d, 0x36, 0x3c, 0x43, 0x3d, 0x37, 0x3d, 0x39, 0x38, + 0x3b, 
0x3e, 0x3c, 0x46, 0x35, 0x35, 0x43, 0x44, 0x39, 0x40, 0x34, 0x39, + 0x3d, 0x34, 0x40, 0x45, 0x38, 0x35, 0x3e, 0x39, 0x3c, 0x44, 0x48, 0x44, + 0x41, 0x3e, 0x3c, 0x45, 0x3a, 0x3c, 0x3c, 0x46, 0x3a, 0x40, 0x39, 0x43, + 0x35, 0x35, 0x3e, 0x45, 0x3a, 0x34, 0x3c, 0x39, 0x46, 0x3a, 0x4f, 0x35, + 0x32, 0x3d, 0x36, 0x41, 0x32, 0x38, 0x3f, 0x45, 0x2d, 0x34, 0x2a, 0x35, + 0x43, 0x3f, 0x41, 0x49, 0x41, 0x3c, 0x4b, 0x3f, 0x17, 0x31, 0x02, 0x4f, + 0x30, 0x38, 0x39, 0x40, 0x33, 0x3a, 0x25, 0x38, 0x35, 0x3c, 0x39, 0x35, + 0x34, 0x41, 0x34, 0x43, 0x40, 0x40, 0x46, 0x3d, 0x40, 0x38, 0x3f, 0x3b, + 0x35, 0x39, 0x3c, 0x39, 0x34, 0x38, 0x3f, 0x36, 0x3a, 0x38, 0x44, 0x3f, + 0x3f, 0x38, 0x3c, 0x33, 0x41, 0x42, 0x38, 0x33, 0x3c, 0x3b, 0x3c, 0x46, + 0x38, 0x3b, 0x3f, 0x33, 0x3f, 0x48, 0x3b, 0x49, 0x3f, 0x3a, 0x3d, 0x3f, + 0x47, 0x3d, 0x30, 0x45, 0x36, 0x42, 0x3d, 0x36, 0x43, 0x38, 0x3b, 0x3d, + 0x3c, 0x30, 0x3b, 0x43, 0x3d, 0x41, 0x34, 0x2e, 0x43, 0x3d, 0x43, 0x46, + 0x43, 0x3c, 0x3c, 0x2e, 0x3c, 0x43, 0x34, 0x43, 0x3e, 0x43, 0x3f, 0x2b, + 0x45, 0x40, 0x3a, 0x43, 0x36, 0x39, 0x3f, 0x3d, 0x3a, 0x3c, 0x35, 0x3b, + 0x36, 0x3f, 0x45, 0x3e, 0x45, 0x40, 0x3f, 0x36, 0x45, 0x42, 0x35, 0x3e, + 0x3a, 0x3a, 0x3f, 0x40, 0x3e, 0x3c, 0x39, 0x46, 0x43, 0x3e, 0x3f, 0x3f, + 0x40, 0x3c, 0x40, 0x4b, 0x41, 0x35, 0x3b, 0x3e, 0x49, 0x32, 0x3e, 0x41, + 0x31, 0x37, 0x3d, 0x3b, 0x3f, 0x45, 0x50, 0x3a, 0x3f, 0x3c, 0x44, 0x36, + 0x43, 0x37, 0x3d, 0x4b, 0x29, 0x39, 0x2f, 0x38, 0x45, 0x36, 0x40, 0x4e, + 0x39, 0x3f, 0x48, 0x43, 0x23, 0x3c, 0x06, 0x51, 0x37, 0x3b, 0x3e, 0x3b, + 0x28, 0x45, 0x2b, 0x37, 0x3f, 0x33, 0x3f, 0x41, 0x31, 0x36, 0x33, 0x3a, + 0x3a, 0x35, 0x3b, 0x33, 0x3e, 0x36, 0x35, 0x40, 0x3a, 0x34, 0x3a, 0x38, + 0x34, 0x3a, 0x3a, 0x34, 0x42, 0x45, 0x40, 0x3e, 0x40, 0x38, 0x39, 0x34, + 0x38, 0x37, 0x3f, 0x3e, 0x3c, 0x32, 0x3f, 0x46, 0x3f, 0x44, 0x3b, 0x3e, + 0x44, 0x45, 0x36, 0x3e, 0x36, 0x3f, 0x3b, 0x40, 0x39, 0x34, 0x38, 0x41, + 0x42, 0x3e, 0x3d, 0x47, 0x3e, 0x45, 0x33, 0x40, 0x3e, 0x3a, 0x44, 0x3d, + 0x3c, 
0x3a, 0x3a, 0x2c, 0x3a, 0x3d, 0x35, 0x45, 0x3c, 0x41, 0x36, 0x30, + 0x32, 0x32, 0x3a, 0x3b, 0x35, 0x3c, 0x43, 0x2d, 0x35, 0x3f, 0x41, 0x37, + 0x3f, 0x46, 0x34, 0x39, 0x3c, 0x43, 0x40, 0x3e, 0x3e, 0x36, 0x3e, 0x3c, + 0x37, 0x3a, 0x3d, 0x3a, 0x3c, 0x38, 0x44, 0x41, 0x3f, 0x3b, 0x3c, 0x47, + 0x40, 0x3b, 0x41, 0x47, 0x3e, 0x45, 0x39, 0x3e, 0x37, 0x45, 0x4b, 0x4c, + 0x37, 0x37, 0x37, 0x3c, 0x3c, 0x3d, 0x40, 0x38, 0x39, 0x3e, 0x43, 0x3f, + 0x38, 0x45, 0x51, 0x3c, 0x31, 0x34, 0x3b, 0x48, 0x46, 0x41, 0x40, 0x40, + 0x2c, 0x39, 0x32, 0x42, 0x3c, 0x2e, 0x49, 0x4d, 0x3c, 0x3f, 0x45, 0x38, + 0x20, 0x38, 0x03, 0x55, 0x33, 0x3e, 0x32, 0x39, 0x32, 0x3b, 0x24, 0x2b, + 0x42, 0x35, 0x45, 0x32, 0x2e, 0x3b, 0x2f, 0x3f, 0x3c, 0x37, 0x39, 0x3b, + 0x34, 0x34, 0x3d, 0x36, 0x3d, 0x39, 0x3b, 0x30, 0x3c, 0x3e, 0x40, 0x32, + 0x3d, 0x3c, 0x3c, 0x3e, 0x33, 0x33, 0x3f, 0x3a, 0x33, 0x3e, 0x46, 0x36, + 0x3a, 0x3d, 0x40, 0x40, 0x3f, 0x41, 0x3a, 0x42, 0x34, 0x32, 0x34, 0x46, + 0x3b, 0x31, 0x40, 0x37, 0x37, 0x32, 0x3e, 0x47, 0x3f, 0x3b, 0x3e, 0x43, + 0x49, 0x45, 0x3a, 0x3d, 0x3e, 0x44, 0x40, 0x31, 0x39, 0x3e, 0x3b, 0x2d, + 0x3b, 0x3a, 0x33, 0x3d, 0x39, 0x37, 0x3e, 0x32, 0x41, 0x3c, 0x3a, 0x37, + 0x3b, 0x40, 0x39, 0x2f, 0x3e, 0x3f, 0x47, 0x32, 0x3e, 0x3b, 0x3e, 0x3e, + 0x40, 0x3e, 0x40, 0x3c, 0x41, 0x39, 0x38, 0x46, 0x45, 0x32, 0x47, 0x31, + 0x36, 0x47, 0x37, 0x49, 0x3a, 0x3f, 0x47, 0x3a, 0x41, 0x3b, 0x3c, 0x4f, + 0x3e, 0x36, 0x3b, 0x47, 0x35, 0x39, 0x41, 0x4e, 0x3d, 0x3e, 0x3b, 0x46, + 0x38, 0x39, 0x3b, 0x45, 0x3e, 0x3f, 0x44, 0x42, 0x44, 0x3f, 0x55, 0x3b, + 0x41, 0x3d, 0x43, 0x43, 0x37, 0x3f, 0x3d, 0x4c, 0x28, 0x3d, 0x36, 0x3c, + 0x3e, 0x3e, 0x48, 0x50, 0x3e, 0x39, 0x45, 0x41, 0x22, 0x37, 0x07, 0x4f, + 0x2e, 0x33, 0x38, 0x3f, 0x31, 0x3a, 0x1b, 0x36, 0x34, 0x38, 0x3c, 0x37, + 0x37, 0x3e, 0x36, 0x35, 0x36, 0x3b, 0x3d, 0x38, 0x42, 0x48, 0x3d, 0x40, + 0x40, 0x44, 0x3d, 0x39, 0x37, 0x3b, 0x3d, 0x33, 0x3d, 0x35, 0x42, 0x3c, + 0x39, 0x3e, 0x43, 0x2d, 0x3c, 0x40, 0x43, 0x43, 0x45, 0x35, 0x3c, 0x44, + 0x34, 
0x3c, 0x3d, 0x31, 0x39, 0x40, 0x39, 0x3d, 0x3e, 0x34, 0x3e, 0x3b, + 0x40, 0x38, 0x42, 0x4a, 0x40, 0x3b, 0x35, 0x3d, 0x36, 0x38, 0x35, 0x42, + 0x3c, 0x3c, 0x3d, 0x3b, 0x38, 0x39, 0x45, 0x28, 0x3a, 0x37, 0x37, 0x35, + 0x3a, 0x3d, 0x35, 0x2a, 0x3c, 0x3f, 0x37, 0x34, 0x37, 0x3f, 0x3e, 0x2b, + 0x39, 0x43, 0x3b, 0x45, 0x35, 0x36, 0x36, 0x42, 0x33, 0x38, 0x3b, 0x35, + 0x31, 0x3f, 0x41, 0x41, 0x3c, 0x41, 0x45, 0x42, 0x3b, 0x3c, 0x39, 0x46, + 0x3c, 0x3e, 0x3a, 0x41, 0x39, 0x3d, 0x41, 0x4b, 0x40, 0x3f, 0x43, 0x3d, + 0x39, 0x39, 0x44, 0x44, 0x37, 0x42, 0x3f, 0x44, 0x3e, 0x37, 0x42, 0x35, + 0x44, 0x3f, 0x40, 0x42, 0x3f, 0x3a, 0x47, 0x3d, 0x38, 0x3a, 0x3b, 0x3a, + 0x42, 0x36, 0x3a, 0x97, 0x32, 0x31, 0x30, 0x36, 0x47, 0x3e, 0x46, 0x51, + 0x42, 0x34, 0x50, 0x34, 0x26, 0x3b, 0x06, 0x55, 0x3c, 0x3b, 0x2d, 0x3a, + 0x37, 0x37, 0x1b, 0x32, 0x39, 0x3d, 0x36, 0x40, 0x3b, 0x3f, 0x33, 0x33, + 0x3d, 0x37, 0x35, 0x37, 0x44, 0x3f, 0x35, 0x39, 0x33, 0x3c, 0x43, 0x39, + 0x3f, 0x42, 0x3e, 0x34, 0x38, 0x38, 0x39, 0x3c, 0x48, 0x3c, 0x2f, 0x30, + 0x40, 0x3c, 0x41, 0x3e, 0x3f, 0x3e, 0x36, 0x43, 0x40, 0x3c, 0x36, 0x43, + 0x43, 0x38, 0x3a, 0x47, 0x3e, 0x37, 0x39, 0x3a, 0x43, 0x45, 0x38, 0x43, + 0x3b, 0x45, 0x37, 0x44, 0x36, 0x45, 0x3a, 0x3e, 0x3e, 0x3e, 0x3d, 0x33, + 0x39, 0x36, 0x48, 0x33, 0x30, 0x42, 0x33, 0x39, 0x37, 0x3a, 0x3f, 0x34, + 0x34, 0x40, 0x40, 0x40, 0x3f, 0x3d, 0x3f, 0x33, 0x41, 0x40, 0x3b, 0x43, + 0x3b, 0x3a, 0x40, 0x3a, 0x38, 0x3e, 0x38, 0x3b, 0x38, 0x42, 0x40, 0x40, + 0x41, 0x35, 0x37, 0x38, 0x3b, 0x3c, 0x39, 0x4b, 0x32, 0x39, 0x42, 0x3c, + 0x36, 0x3d, 0x32, 0x52, 0x3a, 0x31, 0x40, 0x40, 0x3a, 0x43, 0x3d, 0x46, + 0x3c, 0x3e, 0x3e, 0x33, 0x3f, 0x41, 0x4d, 0x37, 0x39, 0x39, 0x3e, 0x3b, + 0x40, 0x39, 0x53, 0x2d, 0x46, 0x3c, 0x32, 0x42, 0x3d, 0x40, 0x40, 0x4d, + 0x2e, 0x34, 0x39, 0x3b, 0x46, 0x3b, 0x42, 0x4f, 0x3d, 0x39, 0x4e, 0x36, + 0x1a, 0x31, 0x0e, 0x56, 0x36, 0x42, 0x38, 0x44, 0x36, 0x3a, 0x20, 0x30, + 0x36, 0x34, 0x37, 0x38, 0x40, 0x41, 0x2a, 0x35, 0x3b, 0x3b, 0x3a, 0x38, + 0x33, 
0x39, 0x36, 0x41, 0x43, 0x39, 0x35, 0x3d, 0x37, 0x3d, 0x33, 0x31, + 0x45, 0x33, 0x3f, 0x3b, 0x44, 0x38, 0x39, 0x34, 0x38, 0x39, 0x38, 0x3d, + 0x3a, 0x3a, 0x41, 0x40, 0x44, 0x3e, 0x3f, 0x45, 0x34, 0x31, 0x34, 0x43, + 0x3b, 0x34, 0x42, 0x3c, 0x3c, 0x43, 0x35, 0x45, 0x36, 0x38, 0x3d, 0x3c, + 0x3f, 0x3d, 0x3e, 0x45, 0x41, 0x43, 0x35, 0x3f, 0x40, 0x3f, 0x3a, 0x34, + 0x3d, 0x32, 0x41, 0x3d, 0x48, 0x42, 0x37, 0x2a, 0x3c, 0x3a, 0x3e, 0x49, + 0x38, 0x36, 0x38, 0x2e, 0x36, 0x37, 0x34, 0x3e, 0x3c, 0x43, 0x43, 0x39, + 0x39, 0x3b, 0x44, 0x46, 0x44, 0x43, 0x37, 0x46, 0x43, 0x34, 0x3b, 0x35, + 0x42, 0x41, 0x3f, 0x3d, 0x3d, 0x3a, 0x42, 0x3e, 0x38, 0x47, 0x3d, 0x49, + 0x45, 0x49, 0x3a, 0x3c, 0x3e, 0x37, 0x40, 0x46, 0x41, 0x33, 0x45, 0x36, + 0x37, 0x44, 0x49, 0x3b, 0x44, 0x40, 0x33, 0x46, 0x37, 0x39, 0x4e, 0x3a, + 0x43, 0x38, 0x3a, 0x42, 0x3a, 0x3d, 0x45, 0x50, 0x26, 0x34, 0x3b, 0x3c, + 0x46, 0x46, 0x4c, 0x54, 0x3f, 0x35, 0x4e, 0x47, 0x21, 0x39, 0x0e, 0x54, + 0x3a, 0x3a, 0x2f, 0x40, 0x2d, 0x3a, 0x1f, 0x31, 0x31, 0x42, 0x34, 0x45, + 0x37, 0x36, 0x30, 0x3b, 0x3a, 0x3a, 0x36, 0x40, 0x32, 0x36, 0x3c, 0x3c, + 0x37, 0x42, 0x35, 0x3e, 0x39, 0x47, 0x36, 0x32, 0x41, 0x30, 0x42, 0x39, + 0x39, 0x44, 0x37, 0x30, 0x41, 0x3b, 0x3d, 0x3d, 0x43, 0x3b, 0x38, 0x45, + 0x3b, 0x3a, 0x39, 0x3a, 0x31, 0x33, 0x43, 0x46, 0x3f, 0x41, 0x44, 0x3f, + 0x3b, 0x44, 0x3a, 0x4c, 0x33, 0x33, 0x33, 0x3e, 0x37, 0x3e, 0x45, 0x45, + 0x36, 0x42, 0x3e, 0x43, 0x40, 0x34, 0x36, 0x31, 0x38, 0x34, 0x41, 0x3b, + 0x32, 0x38, 0x3e, 0x29, 0x47, 0x33, 0x37, 0x45, 0x3c, 0x3d, 0x43, 0x2c, + 0x36, 0x3a, 0x3c, 0x40, 0x3d, 0x46, 0x3c, 0x37, 0x40, 0x44, 0x37, 0x38, + 0x3e, 0x41, 0x3c, 0x40, 0x33, 0x3f, 0x44, 0x32, 0x44, 0x3a, 0x43, 0x42, + 0x3e, 0x38, 0x44, 0x3b, 0x41, 0x48, 0x3f, 0x4e, 0x3f, 0x44, 0x35, 0x45, + 0x34, 0x3f, 0x42, 0x4b, 0x37, 0x37, 0x3e, 0x45, 0x46, 0x45, 0x46, 0x3d, + 0x3e, 0x39, 0x3b, 0x3a, 0x46, 0x3a, 0x56, 0x35, 0x46, 0x3d, 0x40, 0x3b, + 0x36, 0x39, 0x3f, 0x54, 0x27, 0x2b, 0x34, 0x3c, 0x48, 0x3d, 0x49, 0x4c, + 0x3e, 
0x3d, 0x4e, 0x42, 0x25, 0x3b, 0x10, 0x4d, 0x30, 0x36, 0x3e, 0x36, + 0x2e, 0x31, 0x1d, 0x37, 0x3a, 0x39, 0x33, 0x3f, 0x39, 0x38, 0x2e, 0x36, + 0x44, 0x3e, 0x41, 0x37, 0x3b, 0x30, 0x3b, 0x48, 0x31, 0x39, 0x41, 0x3e, + 0x37, 0x37, 0x34, 0x2f, 0x35, 0x3b, 0x3a, 0x3e, 0x45, 0x3e, 0x3f, 0x35, + 0x39, 0x39, 0x3b, 0x44, 0x43, 0x3c, 0x3e, 0x46, 0x40, 0x3a, 0x36, 0x45, + 0x41, 0x40, 0x36, 0x44, 0x3a, 0x37, 0x47, 0x47, 0x3d, 0x36, 0x43, 0x4e, + 0x3b, 0x38, 0x40, 0x48, 0x44, 0x43, 0x45, 0x3f, 0x43, 0x3c, 0x3b, 0x37, + 0x43, 0x41, 0x39, 0x2f, 0x3d, 0x45, 0x3e, 0x3e, 0x42, 0x40, 0x41, 0x2f, + 0x47, 0x38, 0x3a, 0x48, 0x3e, 0x35, 0x37, 0x2a, 0x34, 0x38, 0x41, 0x3b, + 0x3d, 0x37, 0x3b, 0x35, 0x38, 0x3e, 0x41, 0x3c, 0x41, 0x43, 0x3d, 0x46, + 0x47, 0x47, 0x3d, 0x35, 0x48, 0x41, 0x3d, 0x3e, 0x34, 0x47, 0x38, 0x38, + 0x39, 0x3e, 0x38, 0x4d, 0x43, 0x36, 0x42, 0x40, 0x3e, 0x41, 0x3f, 0x4c, + 0x3e, 0x3e, 0x37, 0x44, 0x3e, 0x3b, 0x47, 0x3e, 0x3f, 0x3b, 0x39, 0x3c, + 0x3c, 0x3c, 0x53, 0x3b, 0x3b, 0x32, 0x3e, 0x3f, 0x32, 0x3c, 0x37, 0x4b, + 0x33, 0x30, 0x2f, 0x41, 0x47, 0x42, 0x49, 0x4f, 0x3b, 0x42, 0x4c, 0x44, + 0x1f, 0x37, 0x16, 0x4e, 0x3b, 0x3f, 0x30, 0x36, 0x35, 0x38, 0x26, 0x36, + 0x32, 0x3b, 0x38, 0x3c, 0x30, 0x3e, 0x34, 0x3e, 0x3d, 0x34, 0x39, 0x3c, + 0x36, 0x47, 0x34, 0x41, 0x31, 0x39, 0x44, 0x3e, 0x39, 0x41, 0x32, 0x36, + 0x3b, 0x3f, 0x32, 0x3d, 0x36, 0x3e, 0x40, 0x3d, 0x45, 0x32, 0x45, 0x42, + 0x38, 0x43, 0x40, 0x42, 0x34, 0x3a, 0x43, 0x38, 0x47, 0x3f, 0x41, 0x47, + 0x34, 0x44, 0x41, 0x39, 0x3c, 0x46, 0x36, 0x4f, 0x41, 0x3e, 0x38, 0x38, + 0x3a, 0x3b, 0x43, 0x44, 0x37, 0x3f, 0x35, 0x43, 0x34, 0x3d, 0x40, 0x32, + 0x3a, 0x3b, 0x3d, 0x34, 0x35, 0x43, 0x31, 0x2c, 0x3b, 0x36, 0x38, 0x41, + 0x3c, 0x38, 0x3d, 0x31, 0x45, 0x46, 0x42, 0x41, 0x33, 0x3f, 0x3f, 0x3a, + 0x36, 0x3f, 0x3c, 0x3c, 0x3c, 0x3e, 0x39, 0x3e, 0x40, 0x37, 0x47, 0x3e, + 0x35, 0x39, 0x3d, 0x3d, 0x37, 0x36, 0x3e, 0x45, 0x38, 0x3d, 0x45, 0x43, + 0x3a, 0x32, 0x3b, 0x3a, 0x32, 0x3c, 0x3d, 0x43, 0x3d, 0x33, 0x3b, 0x3d, + 0x46, 
0x3a, 0x44, 0x45, 0x3b, 0x3e, 0x3c, 0x42, 0x37, 0x37, 0x52, 0x2a, + 0x3a, 0x35, 0x35, 0x3f, 0x40, 0x38, 0x40, 0x5b, 0x35, 0x32, 0x2b, 0x3d, + 0x4a, 0x3c, 0x46, 0x56, 0x44, 0x30, 0x4d, 0x39, 0x20, 0x32, 0x0f, 0x4f, + 0x33, 0x3c, 0x35, 0x35, 0x3a, 0x45, 0x29, 0x3b, 0x31, 0x38, 0x34, 0x38, + 0x42, 0x45, 0x37, 0x3e, 0x37, 0x2e, 0x36, 0x43, 0x3f, 0x38, 0x2f, 0x41, + 0x3f, 0x41, 0x3c, 0x31, 0x37, 0x36, 0x37, 0x39, 0x41, 0x3a, 0x3a, 0x40, + 0x3e, 0x47, 0x3d, 0x37, 0x3c, 0x38, 0x35, 0x39, 0x3a, 0x43, 0x3f, 0x42, + 0x42, 0x38, 0x3e, 0x40, 0x3c, 0x3a, 0x45, 0x48, 0x37, 0x3a, 0x3e, 0x35, + 0x3a, 0x3d, 0x45, 0x4a, 0x3d, 0x37, 0x38, 0x3a, 0x3d, 0x46, 0x46, 0x41, + 0x37, 0x41, 0x40, 0x48, 0x37, 0x34, 0x3b, 0x2c, 0x39, 0x34, 0x37, 0x35, + 0x3a, 0x43, 0x39, 0x2e, 0x39, 0x3f, 0x40, 0x3e, 0x40, 0x40, 0x3c, 0x2d, + 0x3e, 0x3c, 0x37, 0x39, 0x3c, 0x3b, 0x3d, 0x3f, 0x41, 0x48, 0x3b, 0x3d, + 0x3b, 0x41, 0x45, 0x3e, 0x3a, 0x38, 0x3f, 0x3c, 0x3d, 0x3e, 0x40, 0x42, + 0x46, 0x38, 0x43, 0x34, 0x35, 0x47, 0x3d, 0x46, 0x3f, 0x3e, 0x32, 0x3f, + 0x3e, 0x3d, 0x47, 0x46, 0x38, 0x41, 0x45, 0x3f, 0x34, 0x3f, 0x41, 0x43, + 0x3e, 0x3e, 0x44, 0x3b, 0x3b, 0x36, 0x51, 0x32, 0x37, 0x3c, 0x42, 0x43, + 0x33, 0x39, 0x42, 0x61, 0x2c, 0x3b, 0x2e, 0x39, 0x42, 0x39, 0x42, 0x54, + 0x3c, 0x3a, 0x48, 0x35, 0x26, 0x34, 0x15, 0x51, 0x35, 0x40, 0x36, 0x3c, + 0x2d, 0x37, 0x25, 0x38, 0x33, 0x3d, 0x3d, 0x39, 0x3e, 0x3b, 0x2e, 0x4b, + 0x3d, 0x3b, 0x42, 0x37, 0x37, 0x40, 0x37, 0x40, 0x35, 0x45, 0x37, 0x37, + 0x3f, 0x41, 0x36, 0x39, 0x3c, 0x32, 0x3e, 0x38, 0x41, 0x40, 0x3e, 0x3f, + 0x3b, 0x3c, 0x43, 0x35, 0x3e, 0x3d, 0x44, 0x44, 0x3a, 0x36, 0x39, 0x3f, + 0x3a, 0x31, 0x42, 0x4d, 0x40, 0x33, 0x40, 0x45, 0x44, 0x3d, 0x40, 0x49, + 0x41, 0x3f, 0x42, 0x3a, 0x34, 0x46, 0x38, 0x46, 0x42, 0x34, 0x3a, 0x40, + 0x40, 0x41, 0x3d, 0x32, 0x35, 0x48, 0x35, 0x3e, 0x44, 0x41, 0x40, 0x2c, + 0x46, 0x38, 0x38, 0x3f, 0x36, 0x40, 0x38, 0x2a, 0x43, 0x41, 0x3e, 0x35, + 0x46, 0x3a, 0x45, 0x46, 0x46, 0x42, 0x3a, 0x3b, 0x40, 0x38, 0x35, 0x43, + 0x38, 
0x3d, 0x3b, 0x41, 0x36, 0x44, 0x3f, 0x3f, 0x34, 0x3e, 0x3c, 0x3d, + 0x49, 0x36, 0x37, 0x4b, 0x38, 0x3c, 0x43, 0x37, 0x3a, 0x3f, 0x31, 0x45, + 0x3b, 0x39, 0x3f, 0x40, 0x37, 0x3c, 0x42, 0x3f, 0x3c, 0x33, 0x40, 0x3b, + 0x32, 0x3c, 0x52, 0x31, 0x3d, 0x44, 0x3b, 0x31, 0x46, 0x38, 0x40, 0x60, + 0x2b, 0x3c, 0x37, 0x34, 0x43, 0x38, 0x45, 0x57, 0x37, 0x39, 0x49, 0x33, + 0x2d, 0x3f, 0x18, 0x4e, 0x39, 0x39, 0x32, 0x3b, 0x34, 0x3b, 0x2c, 0x45, + 0x33, 0x37, 0x45, 0x42, 0x3d, 0x37, 0x2a, 0x4c, 0x3d, 0x3f, 0x3c, 0x36, + 0x37, 0x3c, 0x39, 0x47, 0x3d, 0x44, 0x3d, 0x40, 0x3d, 0x41, 0x34, 0x3e, + 0x40, 0x34, 0x3b, 0x3a, 0x41, 0x36, 0x37, 0x40, 0x3e, 0x3f, 0x3a, 0x36, + 0x3e, 0x35, 0x3b, 0x48, 0x41, 0x40, 0x3c, 0x42, 0x34, 0x41, 0x3f, 0x44, + 0x34, 0x39, 0x33, 0x39, 0x39, 0x47, 0x40, 0x48, 0x38, 0x3a, 0x43, 0x43, + 0x48, 0x3a, 0x3f, 0x46, 0x35, 0x3a, 0x33, 0x36, 0x32, 0x3c, 0x40, 0x34, + 0x40, 0x3a, 0x42, 0x3a, 0x39, 0x38, 0x41, 0x35, 0x3a, 0x3f, 0x35, 0x40, + 0x3f, 0x39, 0x39, 0x36, 0x38, 0x40, 0x3e, 0x3e, 0x3a, 0x31, 0x32, 0x44, + 0x40, 0x47, 0x3a, 0x3c, 0x43, 0x43, 0x46, 0x48, 0x40, 0x35, 0x3d, 0x37, + 0x44, 0x37, 0x33, 0x44, 0x3b, 0x3e, 0x3f, 0x37, 0x36, 0x3a, 0x38, 0x47, + 0x3a, 0x44, 0x36, 0x42, 0x3e, 0x44, 0x34, 0x46, 0x33, 0x43, 0x44, 0x3e, + 0x30, 0x48, 0x37, 0x38, 0x33, 0x3c, 0x46, 0x42, 0x38, 0x3d, 0x50, 0x39, + 0x33, 0x38, 0x3e, 0x40, 0x3b, 0x2b, 0x3b, 0x5f, 0x2b, 0x32, 0x2f, 0x37, + 0x3f, 0x3a, 0x40, 0x4e, 0x34, 0x38, 0x47, 0x37, 0x27, 0x2b, 0x1b, 0x4f, + 0x36, 0x38, 0x3a, 0x3a, 0x3b, 0x38, 0x2e, 0x3f, 0x3f, 0x42, 0x42, 0x42, + 0x36, 0x3e, 0x3c, 0x55, 0x39, 0x40, 0x44, 0x43, 0x3e, 0x33, 0x3c, 0x43, + 0x38, 0x44, 0x3b, 0x46, 0x3f, 0x45, 0x34, 0x38, 0x3c, 0x41, 0x42, 0x3d, + 0x42, 0x36, 0x43, 0x3f, 0x3c, 0x39, 0x3e, 0x39, 0x39, 0x42, 0x33, 0x47, + 0x36, 0x3d, 0x3f, 0x3b, 0x40, 0x39, 0x3b, 0x49, 0x36, 0x40, 0x3d, 0x41, + 0x40, 0x34, 0x3b, 0x4e, 0x3b, 0x36, 0x3b, 0x45, 0x40, 0x32, 0x3b, 0x49, + 0x37, 0x38, 0x3a, 0x47, 0x37, 0x40, 0x3e, 0x38, 0x40, 0x3f, 0x3c, 0x3a, + 0x47, 
0x41, 0x42, 0x30, 0x40, 0x3c, 0x42, 0x3f, 0x31, 0x44, 0x39, 0x38, + 0x3b, 0x38, 0x42, 0x43, 0x41, 0x35, 0x3a, 0x39, 0x3e, 0x38, 0x39, 0x3e, + 0x3c, 0x42, 0x3d, 0x49, 0x47, 0x3c, 0x3f, 0x35, 0x41, 0x3a, 0x36, 0x43, + 0x43, 0x3b, 0x39, 0x3b, 0x36, 0x43, 0x43, 0x4e, 0x3e, 0x35, 0x37, 0x3b, + 0x3f, 0x37, 0x41, 0x48, 0x32, 0x44, 0x43, 0x32, 0x38, 0x39, 0x45, 0x39, + 0x3e, 0x3d, 0x35, 0x39, 0x35, 0x39, 0x50, 0x37, 0x39, 0x40, 0x43, 0x47, + 0x32, 0x2a, 0x40, 0x62, 0x24, 0x30, 0x36, 0x3e, 0x41, 0x32, 0x47, 0x58, + 0x39, 0x36, 0x44, 0x34, 0x26, 0x34, 0x1e, 0x50, 0x3c, 0x3b, 0x3f, 0x42, + 0x35, 0x3d, 0x2a, 0x4e, 0x40, 0x38, 0x36, 0x31, 0x3a, 0x30, 0x37, 0x4b, + 0x3c, 0x3b, 0x3b, 0x41, 0x3b, 0x3c, 0x2e, 0x45, 0x44, 0x3f, 0x3b, 0x35, + 0x3e, 0x33, 0x37, 0x3d, 0x40, 0x39, 0x39, 0x37, 0x40, 0x3e, 0x3a, 0x3e, + 0x3c, 0x3c, 0x45, 0x40, 0x3c, 0x3f, 0x3a, 0x51, 0x47, 0x3a, 0x34, 0x39, + 0x3b, 0x34, 0x44, 0x4c, 0x36, 0x3d, 0x3a, 0x35, 0x34, 0x36, 0x38, 0x4b, + 0x3f, 0x40, 0x3f, 0x3e, 0x40, 0x41, 0x47, 0x43, 0x32, 0x38, 0x46, 0x44, + 0x46, 0x43, 0x43, 0x37, 0x39, 0x49, 0x37, 0x36, 0x3e, 0x3d, 0x37, 0x3c, + 0x39, 0x37, 0x34, 0x43, 0x45, 0x32, 0x3a, 0x3a, 0x38, 0x43, 0x3b, 0x40, + 0x3b, 0x3f, 0x3d, 0x41, 0x40, 0x3d, 0x3a, 0x3b, 0x48, 0x37, 0x3d, 0x41, + 0x40, 0x3e, 0x38, 0x41, 0x3d, 0x3a, 0x38, 0x49, 0x40, 0x3c, 0x42, 0x41, + 0x3a, 0x38, 0x38, 0x4c, 0x3e, 0x41, 0x40, 0x3b, 0x3d, 0x3e, 0x3c, 0x46, + 0x3e, 0x42, 0x41, 0x38, 0x42, 0x42, 0x41, 0x3e, 0x3e, 0x37, 0x3c, 0x43, + 0x43, 0x3b, 0x54, 0x2b, 0x45, 0x3b, 0x43, 0x41, 0x41, 0x26, 0x3f, 0x60, + 0x25, 0x2b, 0x2e, 0x3a, 0x40, 0x31, 0x40, 0x49, 0x40, 0x31, 0x46, 0x3c, + 0x1e, 0x2a, 0x1a, 0x47, 0x33, 0x37, 0x37, 0x34, 0x31, 0x36, 0x25, 0x41, + 0x2e, 0x36, 0x35, 0x33, 0x33, 0x34, 0x31, 0x45, 0x3a, 0x3f, 0x3d, 0x40, + 0x3c, 0x41, 0x30, 0x3c, 0x3f, 0x46, 0x37, 0x3c, 0x3a, 0x3c, 0x36, 0x3a, + 0x47, 0x3d, 0x31, 0x3f, 0x40, 0x3e, 0x36, 0x44, 0x41, 0x3d, 0x36, 0x3f, + 0x37, 0x3f, 0x34, 0x4b, 0x31, 0x47, 0x43, 0x3e, 0x3e, 0x3a, 0x3b, 0x4b, + 0x37, 
0x32, 0x38, 0x3d, 0x37, 0x47, 0x46, 0x4d, 0x36, 0x3c, 0x3f, 0x3a, + 0x41, 0x31, 0x47, 0x43, 0x3d, 0x3d, 0x3e, 0x35, 0x3d, 0x46, 0x49, 0x2a, + 0x37, 0x3c, 0x39, 0x3d, 0x47, 0x3c, 0x34, 0x2c, 0x3e, 0x38, 0x47, 0x32, + 0x36, 0x36, 0x41, 0x38, 0x35, 0x44, 0x48, 0x3b, 0x39, 0x3e, 0x38, 0x3e, + 0x40, 0x36, 0x37, 0x46, 0x39, 0x3b, 0x34, 0x45, 0x40, 0x3b, 0x48, 0x36, + 0x34, 0x44, 0x37, 0x46, 0x3f, 0x42, 0x33, 0x36, 0x43, 0x3c, 0x41, 0x46, + 0x31, 0x42, 0x43, 0x44, 0x44, 0x3e, 0x42, 0x3b, 0x3b, 0x3a, 0x3c, 0x37, + 0x42, 0x41, 0x46, 0x38, 0x41, 0x3b, 0x40, 0x44, 0x37, 0x3c, 0x4c, 0x2e, + 0x3a, 0x3e, 0x3b, 0x36, 0x33, 0x27, 0x37, 0x5d, 0x27, 0x34, 0x32, 0x41, + 0x41, 0x3f, 0x40, 0x5d, 0x40, 0x3d, 0x48, 0x39, 0x2e, 0x30, 0x1f, 0x3f, + 0x38, 0x3f, 0x40, 0x33, 0x40, 0x38, 0x31, 0x3f, 0x42, 0x3e, 0x3b, 0x3a, + 0x42, 0x36, 0x3a, 0x42, 0x3c, 0x3b, 0x3d, 0x41, 0x3d, 0x40, 0x40, 0x3e, + 0x36, 0x41, 0x47, 0x3d, 0x33, 0x32, 0x33, 0x44, 0x3e, 0x3a, 0x3e, 0x3d, + 0x45, 0x3f, 0x38, 0x3f, 0x40, 0x3a, 0x3c, 0x46, 0x32, 0x42, 0x3c, 0x51, + 0x33, 0x38, 0x3a, 0x38, 0x41, 0x34, 0x45, 0x4e, 0x35, 0x3c, 0x42, 0x3e, + 0x3f, 0x45, 0x44, 0x4e, 0x39, 0x47, 0x3a, 0x33, 0x3e, 0x3b, 0x45, 0x42, + 0x37, 0x3a, 0x3e, 0x33, 0x41, 0x48, 0x32, 0x2a, 0x3b, 0x37, 0x3f, 0x3d, + 0x3a, 0x42, 0x41, 0x2f, 0x34, 0x3e, 0x49, 0x3b, 0x38, 0x3e, 0x3d, 0x3a, + 0x37, 0x3c, 0x44, 0x41, 0x39, 0x42, 0x3f, 0x39, 0x40, 0x35, 0x3d, 0x41, + 0x3b, 0x45, 0x44, 0x48, 0x3d, 0x42, 0x36, 0x33, 0x3e, 0x44, 0x3f, 0x41, + 0x42, 0x40, 0x49, 0x34, 0x48, 0x41, 0x3f, 0x40, 0x3c, 0x45, 0x47, 0x34, + 0x41, 0x37, 0x47, 0x3e, 0x41, 0x41, 0x39, 0x42, 0x3f, 0x3a, 0x46, 0x33, + 0x39, 0x41, 0x38, 0x38, 0x3e, 0x42, 0x41, 0x38, 0x35, 0x32, 0x33, 0x38, + 0x3a, 0x3f, 0x45, 0x66, 0x33, 0x47, 0x38, 0x3c, 0x41, 0x2f, 0x48, 0x55, + 0x33, 0x3e, 0x49, 0x3b, 0x3c, 0x30, 0x24, 0x45, 0x3c, 0x44, 0x43, 0x32, + 0x3d, 0x3f, 0x35, 0x3b, 0x3e, 0x36, 0x38, 0x3a, 0x36, 0x37, 0x3b, 0x41, + 0x38, 0x42, 0x3e, 0x43, 0x39, 0x3f, 0x3c, 0x40, 0x37, 0x43, 0x3e, 0x3b, + 0x3d, 
0x35, 0x35, 0x3d, 0x43, 0x3f, 0x3a, 0x35, 0x37, 0x3c, 0x31, 0x47, + 0x44, 0x45, 0x40, 0x32, 0x44, 0x36, 0x38, 0x51, 0x3c, 0x41, 0x45, 0x37, + 0x39, 0x44, 0x3e, 0x4f, 0x3c, 0x3a, 0x38, 0x40, 0x3f, 0x34, 0x39, 0x4e, + 0x3d, 0x39, 0x45, 0x3f, 0x3e, 0x3c, 0x3b, 0x42, 0x3b, 0x3b, 0x34, 0x3d, + 0x41, 0x44, 0x39, 0x2e, 0x37, 0x44, 0x45, 0x37, 0x3d, 0x41, 0x3f, 0x33, + 0x3f, 0x3e, 0x3e, 0x40, 0x44, 0x3f, 0x37, 0x32, 0x35, 0x3e, 0x43, 0x41, + 0x39, 0x37, 0x35, 0x3f, 0x48, 0x3d, 0x43, 0x49, 0x38, 0x35, 0x3f, 0x48, + 0x3b, 0x3a, 0x34, 0x3f, 0x3c, 0x44, 0x3a, 0x40, 0x36, 0x35, 0x44, 0x36, + 0x44, 0x3b, 0x3d, 0x38, 0x3c, 0x44, 0x47, 0x3a, 0x3b, 0x45, 0x41, 0x3a, + 0x39, 0x35, 0x44, 0x3a, 0x49, 0x36, 0x48, 0x31, 0x42, 0x43, 0x42, 0x34, + 0x41, 0x40, 0x4d, 0x36, 0x3e, 0x35, 0x39, 0x3b, 0x3f, 0x41, 0x38, 0x39, + 0x3c, 0x44, 0x3f, 0x39, 0x3a, 0x36, 0x3d, 0x36, 0x3a, 0x3a, 0x34, 0x3b, + 0x38, 0x2f, 0x40, 0x34, 0x32, 0x4d, 0x43, 0x45, 0x4e, 0x3f, 0x48, 0x35, + 0x3b, 0x4d, 0x4f, 0x39, 0x42, 0x36, 0x46, 0x36, 0x4a, 0x3c, 0x37, 0x41, + 0x40, 0x43, 0x50, 0x36, 0x3e, 0x39, 0x44, 0x40, 0x36, 0x47, 0x3f, 0x36, + 0x45, 0x40, 0x45, 0x41, 0x3b, 0x37, 0x41, 0x39, 0x3b, 0x48, 0x37, 0x34, + 0x41, 0x45, 0x49, 0x3f, 0x39, 0x49, 0x3f, 0x3a, 0x42, 0x34, 0x38, 0x37, + 0x44, 0x34, 0x3c, 0x3d, 0x40, 0x47, 0x3a, 0x36, 0x3f, 0x3c, 0x41, 0x3e, + 0x47, 0x46, 0x46, 0x43, 0x3f, 0x38, 0x3b, 0x40, 0x3f, 0x48, 0x3b, 0x4c, + 0x3d, 0x4b, 0x34, 0x3b, 0x44, 0x43, 0x3c, 0x49, 0x38, 0x42, 0x41, 0x36, + 0x33, 0x36, 0x40, 0x46, 0x40, 0x3a, 0x42, 0x3c, 0x3d, 0x35, 0x3c, 0x52, + 0x3e, 0x40, 0x43, 0x43, 0x41, 0x3b, 0x3e, 0x44, 0x3f, 0x40, 0x40, 0x43, + 0x3d, 0x3f, 0x36, 0x42, 0x3f, 0x3c, 0x34, 0x3d, 0x33, 0x41, 0x3c, 0x39, + 0x34, 0x43, 0x3f, 0x34, 0x3c, 0x3a, 0x3a, 0x37, 0x42, 0x41, 0x40, 0x3e, + 0x3d, 0x3c, 0x41, 0x3c, 0x38, 0x33, 0x49, 0x46, 0x40, 0x40, 0x3a, 0x46, + 0x38, 0x3c, 0x37, 0x34, 0x3e, 0x3d, 0x32, 0x38, 0x3c, 0x4c, 0x3a, 0x34, + 0x35, 0x32, 0x39, 0x40, 0x3a, 0x58, 0x40, 0x46, 0x42, 0x33, 0x45, 0x39, + 0x34, 
0x4f, 0x53, 0x45, 0x43, 0x3e, 0x41, 0x36, 0x3e, 0x3f, 0x40, 0x47, + 0x4e, 0x3d, 0x53, 0x2b, 0x41, 0x36, 0x3e, 0x38, 0x47, 0x41, 0x3f, 0x34, + 0x47, 0x40, 0x38, 0x39, 0x3d, 0x42, 0x3f, 0x3c, 0x48, 0x3a, 0x35, 0x3c, + 0x45, 0x49, 0x3c, 0x33, 0x33, 0x3f, 0x3c, 0x46, 0x43, 0x3f, 0x45, 0x31, + 0x35, 0x43, 0x46, 0x3a, 0x45, 0x3c, 0x37, 0x3a, 0x37, 0x36, 0x35, 0x3f, + 0x38, 0x49, 0x34, 0x3f, 0x3c, 0x42, 0x49, 0x3e, 0x3e, 0x3c, 0x39, 0x49, + 0x3e, 0x3c, 0x3b, 0x43, 0x44, 0x45, 0x39, 0x4b, 0x47, 0x47, 0x3e, 0x33, + 0x3c, 0x31, 0x34, 0x4f, 0x45, 0x43, 0x40, 0x3d, 0x42, 0x3b, 0x43, 0x50, + 0x3c, 0x3b, 0x37, 0x42, 0x47, 0x42, 0x3e, 0x4a, 0x3f, 0x3a, 0x48, 0x3d, + 0x48, 0x45, 0x3e, 0x40, 0x3a, 0x3c, 0x3d, 0x39, 0x41, 0x42, 0x3c, 0x42, + 0x43, 0x3c, 0x3b, 0x3d, 0x47, 0x49, 0x38, 0x3c, 0x46, 0x3a, 0x3c, 0x3f, + 0x3a, 0x46, 0x3a, 0x3b, 0x3d, 0x3a, 0x49, 0x46, 0x38, 0x40, 0x3e, 0x38, + 0x37, 0x32, 0x40, 0x3c, 0x42, 0x3d, 0x3b, 0x40, 0x3a, 0x38, 0x49, 0x33, + 0x40, 0x38, 0x2b, 0x3a, 0x3c, 0x4f, 0x4d, 0x3e, 0x35, 0x3d, 0x3b, 0x40, + 0x3a, 0x54, 0x3e, 0x3e, 0x43, 0x30, 0x47, 0x3d, 0x3b, 0x53, 0x52, 0x4a, + 0x43, 0x41, 0x49, 0x37, 0x3b, 0x35, 0x44, 0x3c, 0x45, 0x40, 0x4f, 0x36, + 0x4b, 0x42, 0x41, 0x3a, 0x41, 0x44, 0x47, 0x32, 0x43, 0x35, 0x3f, 0x37, + 0x43, 0x41, 0x43, 0x36, 0x3f, 0x3b, 0x3d, 0x38, 0x3d, 0x40, 0x42, 0x36, + 0x44, 0x3a, 0x39, 0x47, 0x37, 0x34, 0x42, 0x3a, 0x37, 0x38, 0x37, 0x3f, + 0x36, 0x3b, 0x45, 0x3f, 0x3f, 0x3d, 0x39, 0x3d, 0x39, 0x41, 0x37, 0x3f, + 0x3f, 0x3d, 0x3f, 0x41, 0x43, 0x41, 0x45, 0x43, 0x41, 0x3c, 0x3e, 0x40, + 0x40, 0x39, 0x41, 0x4f, 0x47, 0x42, 0x46, 0x48, 0x3b, 0x3b, 0x3c, 0x46, + 0x47, 0x3e, 0x46, 0x37, 0x38, 0x3d, 0x38, 0x52, 0x36, 0x46, 0x3c, 0x3a, + 0x3b, 0x37, 0x48, 0x4b, 0x3f, 0x42, 0x3c, 0x36, 0x40, 0x37, 0x33, 0x4c, + 0x39, 0x34, 0x41, 0x34, 0x3f, 0x3b, 0x35, 0x4b, 0x3b, 0x45, 0x43, 0x31, + 0x3e, 0x39, 0x30, 0x3d, 0x32, 0x43, 0x44, 0x3c, 0x3e, 0x38, 0x43, 0x41, + 0x3e, 0x37, 0x41, 0x39, 0x39, 0x44, 0x43, 0x38, 0x3f, 0x37, 0x48, 0x3f, + 0x3b, 
0x44, 0x37, 0x3f, 0x3a, 0x3f, 0x3b, 0x33, 0x42, 0x3e, 0x2f, 0x42, + 0x44, 0x4f, 0x52, 0x3c, 0x34, 0x33, 0x39, 0x46, 0x31, 0x55, 0x43, 0x4e, + 0x49, 0x38, 0x4d, 0x48, 0x34, 0x4d, 0x5c, 0x4d, 0x49, 0x37, 0x4f, 0x40, + 0x3c, 0x3d, 0x41, 0x42, 0x3f, 0x51, 0x4b, 0x2f, 0x46, 0x35, 0x39, 0x3c, + 0x49, 0x3d, 0x4e, 0x32, 0x43, 0x47, 0x31, 0x3e, 0x42, 0x4a, 0x4c, 0x39, + 0x43, 0x46, 0x3e, 0x3f, 0x44, 0x3c, 0x42, 0x30, 0x3e, 0x34, 0x3b, 0x3b, + 0x3a, 0x3c, 0x42, 0x3d, 0x3d, 0x48, 0x48, 0x36, 0x3a, 0x45, 0x38, 0x40, + 0x3c, 0x41, 0x3f, 0x49, 0x42, 0x41, 0x38, 0x3d, 0x3d, 0x44, 0x3b, 0x3d, + 0x35, 0x48, 0x43, 0x3b, 0x32, 0x41, 0x3e, 0x3a, 0x46, 0x41, 0x40, 0x54, + 0x38, 0x3f, 0x3c, 0x36, 0x3b, 0x36, 0x43, 0x50, 0x38, 0x3c, 0x44, 0x3b, + 0x43, 0x47, 0x32, 0x50, 0x3d, 0x46, 0x3d, 0x3b, 0x39, 0x37, 0x3b, 0x4a, + 0x47, 0x43, 0x46, 0x3d, 0x3d, 0x41, 0x43, 0x45, 0x3b, 0x3c, 0x39, 0x47, + 0x43, 0x42, 0x39, 0x4c, 0x34, 0x41, 0x45, 0x3b, 0x38, 0x3e, 0x37, 0x3f, + 0x45, 0x43, 0x39, 0x42, 0x3c, 0x3d, 0x3d, 0x3c, 0x48, 0x39, 0x3b, 0x3a, + 0x46, 0x45, 0x3d, 0x3a, 0x3f, 0x3a, 0x45, 0x36, 0x3d, 0x43, 0x36, 0x43, + 0x42, 0x3d, 0x41, 0x3f, 0x3a, 0x3f, 0x31, 0x37, 0x48, 0x4f, 0x4e, 0x36, + 0x30, 0x3a, 0x3e, 0x3e, 0x38, 0x57, 0x40, 0x47, 0x47, 0x38, 0x4f, 0x46, + 0x3d, 0x4a, 0x50, 0x4c, 0x42, 0x3b, 0x4d, 0x3d, 0x3d, 0x33, 0x40, 0x41, + 0x48, 0x4b, 0x46, 0x39, 0x4d, 0x30, 0x45, 0x38, 0x48, 0x3c, 0x48, 0x3b, + 0x4d, 0x40, 0x3b, 0x40, 0x46, 0x41, 0x51, 0x34, 0x40, 0x43, 0x3f, 0x42, + 0x45, 0x42, 0x3e, 0x35, 0x3d, 0x38, 0x37, 0x3a, 0x42, 0x40, 0x43, 0x3c, + 0x3c, 0x3d, 0x43, 0x40, 0x45, 0x3a, 0x3e, 0x3a, 0x3e, 0x40, 0x43, 0x35, + 0x37, 0x3f, 0x3f, 0x3e, 0x39, 0x3f, 0x47, 0x38, 0x3e, 0x44, 0x3b, 0x3c, + 0x3b, 0x32, 0x40, 0x3e, 0x42, 0x45, 0x3a, 0x52, 0x3a, 0x3e, 0x45, 0x40, + 0x41, 0x48, 0x3f, 0x4e, 0x3e, 0x42, 0x3d, 0x39, 0x3a, 0x33, 0x3f, 0x4b, + 0x3e, 0x38, 0x36, 0x3e, 0x31, 0x41, 0x3a, 0x40, 0x3b, 0x37, 0x3f, 0x3e, + 0x3e, 0x3f, 0x35, 0x44, 0x3d, 0x42, 0x3d, 0x44, 0x42, 0x3f, 0x3e, 0x44, + 0x3e, 
0x45, 0x37, 0x3a, 0x3b, 0x42, 0x3f, 0x41, 0x3b, 0x3f, 0x41, 0x41, + 0x3e, 0x34, 0x47, 0x39, 0x46, 0x46, 0x37, 0x39, 0x3f, 0x45, 0x39, 0x39, + 0x3a, 0x40, 0x38, 0x3a, 0x31, 0x34, 0x3a, 0x41, 0x38, 0x41, 0x3a, 0x41, + 0x44, 0x37, 0x2d, 0x41, 0x43, 0x4d, 0x4b, 0x3b, 0x2c, 0x30, 0x42, 0x3b, + 0x31, 0x56, 0x43, 0x47, 0x47, 0x38, 0x50, 0x44, 0x40, 0x52, 0x5a, 0x50, + 0x44, 0x3f, 0x4b, 0x35, 0x3a, 0x36, 0x41, 0x44, 0x47, 0x4e, 0x52, 0x36, + 0x45, 0x39, 0x38, 0x3c, 0x42, 0x44, 0x40, 0x3b, 0x4b, 0x38, 0x35, 0x35, + 0x3f, 0x40, 0x4f, 0x39, 0x3d, 0x37, 0x34, 0x3e, 0x41, 0x4c, 0x40, 0x37, + 0x3d, 0x3b, 0x37, 0x37, 0x40, 0x42, 0x35, 0x39, 0x41, 0x42, 0x3d, 0x34, + 0x3c, 0x37, 0x3a, 0x3d, 0x46, 0x46, 0x46, 0x3f, 0x44, 0x3d, 0x3c, 0x40, + 0x3c, 0x3a, 0x3d, 0x3b, 0x3b, 0x41, 0x47, 0x3a, 0x43, 0x43, 0x43, 0x3b, + 0x3e, 0x3e, 0x42, 0x46, 0x36, 0x37, 0x45, 0x35, 0x3c, 0x3b, 0x31, 0x4b, + 0x3c, 0x3e, 0x3a, 0x3a, 0x42, 0x42, 0x34, 0x47, 0x37, 0x34, 0x41, 0x3d, + 0x3e, 0x39, 0x43, 0x47, 0x31, 0x3b, 0x40, 0x3b, 0x42, 0x3d, 0x44, 0x44, + 0x37, 0x39, 0x44, 0x3b, 0x40, 0x3a, 0x3d, 0x44, 0x3c, 0x40, 0x42, 0x3b, + 0x40, 0x3e, 0x32, 0x3d, 0x3c, 0x3e, 0x44, 0x3e, 0x47, 0x3d, 0x3f, 0x2e, + 0x3e, 0x3d, 0x3f, 0x3b, 0x3b, 0x43, 0x43, 0x3c, 0x3a, 0x3c, 0x3a, 0x36, + 0x38, 0x46, 0x30, 0x3e, 0x3f, 0x35, 0x3e, 0x34, 0x3c, 0x34, 0x32, 0x4a, + 0x41, 0x48, 0x48, 0x3f, 0x34, 0x37, 0x42, 0x43, 0x36, 0x59, 0x42, 0x3f, + 0x4b, 0x3d, 0x5d, 0x45, 0x3b, 0x51, 0x51, 0x4c, 0x41, 0x40, 0x4d, 0x36, + 0x3f, 0x34, 0x39, 0x3d, 0x4a, 0x4b, 0x4f, 0x33, 0x48, 0x32, 0x3c, 0x32, + 0x48, 0x4c, 0x4d, 0x3a, 0x49, 0x3a, 0x3a, 0x2e, 0x4b, 0x44, 0x4f, 0x33, + 0x3a, 0x48, 0x34, 0x43, 0x38, 0x45, 0x44, 0x35, 0x3b, 0x3f, 0x40, 0x37, + 0x35, 0x34, 0x38, 0x3e, 0x41, 0x3e, 0x3b, 0x47, 0x41, 0x47, 0x3c, 0x3c, + 0x39, 0x40, 0x3e, 0x45, 0x36, 0x41, 0x3f, 0x3f, 0x3c, 0x44, 0x3f, 0x43, + 0x3d, 0x3c, 0x49, 0x42, 0x3e, 0x3f, 0x48, 0x37, 0x43, 0x37, 0x43, 0x3d, + 0x32, 0x42, 0x44, 0x39, 0x36, 0x37, 0x40, 0x46, 0x47, 0x3d, 0x3a, 0x42, + 0x3f, 
0x38, 0x37, 0x48, 0x39, 0x40, 0x3c, 0x37, 0x33, 0x38, 0x38, 0x40, + 0x41, 0x3c, 0x3f, 0x3b, 0x40, 0x3a, 0x47, 0x46, 0x3a, 0x37, 0x42, 0x47, + 0x3b, 0x3f, 0x3b, 0x40, 0x33, 0x3f, 0x3a, 0x3c, 0x38, 0x3a, 0x36, 0x38, + 0x36, 0x40, 0x48, 0x42, 0x48, 0x3c, 0x43, 0x36, 0x32, 0x3b, 0x34, 0x39, + 0x38, 0x46, 0x37, 0x3b, 0x44, 0x34, 0x36, 0x38, 0x3c, 0x43, 0x33, 0x3c, + 0x3b, 0x45, 0x38, 0x38, 0x44, 0x33, 0x36, 0x4a, 0x46, 0x4c, 0x4a, 0x34, + 0x36, 0x37, 0x43, 0x42, 0x33, 0x58, 0x43, 0x48, 0x44, 0x38, 0x5f, 0x3f, + 0x3c, 0x4d, 0x53, 0x52, 0x43, 0x47, 0x52, 0x3e, 0x3b, 0x2d, 0x3b, 0x3a, + 0x4b, 0x49, 0x53, 0x38, 0x4c, 0x2f, 0x38, 0x31, 0x42, 0x40, 0x48, 0x3f, + 0x44, 0x3c, 0x3c, 0x34, 0x46, 0x3f, 0x49, 0x3a, 0x43, 0x3d, 0x34, 0x42, + 0x36, 0x47, 0x51, 0x3c, 0x3d, 0x39, 0x39, 0x3a, 0x3b, 0x35, 0x35, 0x41, + 0x47, 0x3c, 0x3b, 0x43, 0x3f, 0x45, 0x3e, 0x40, 0x3c, 0x3f, 0x3c, 0x42, + 0x3b, 0x3e, 0x38, 0x3f, 0x3f, 0x41, 0x39, 0x39, 0x3d, 0x43, 0x4f, 0x3d, + 0x48, 0x3b, 0x44, 0x45, 0x3d, 0x3b, 0x49, 0x43, 0x44, 0x3d, 0x37, 0x3b, + 0x3c, 0x45, 0x46, 0x44, 0x35, 0x3e, 0x32, 0x35, 0x34, 0x3b, 0x40, 0x43, + 0x3e, 0x45, 0x37, 0x3d, 0x3f, 0x43, 0x36, 0x3f, 0x3f, 0x43, 0x39, 0x44, + 0x3e, 0x3e, 0x45, 0x40, 0x3e, 0x44, 0x3b, 0x3e, 0x42, 0x42, 0x3b, 0x3d, + 0x3a, 0x40, 0x39, 0x3a, 0x32, 0x36, 0x41, 0x30, 0x39, 0x46, 0x33, 0x3f, + 0x46, 0x40, 0x3c, 0x31, 0x41, 0x3a, 0x3f, 0x3f, 0x3b, 0x36, 0x3f, 0x38, + 0x36, 0x3e, 0x35, 0x35, 0x3b, 0x3d, 0x3f, 0x39, 0x46, 0x37, 0x3a, 0x47, + 0x37, 0x39, 0x2c, 0x55, 0x40, 0x4b, 0x4a, 0x39, 0x35, 0x42, 0x3d, 0x40, + 0x3a, 0x54, 0x41, 0x48, 0x51, 0x3b, 0x61, 0x3e, 0x3e, 0x4d, 0x51, 0x52, + 0x3e, 0x43, 0x52, 0x41, 0x48, 0x2d, 0x35, 0x35, 0x4b, 0x44, 0x4d, 0x3c, + 0x54, 0x33, 0x39, 0x27, 0x4a, 0x44, 0x4a, 0x41, 0x3c, 0x3a, 0x31, 0x2f, + 0x3d, 0x42, 0x48, 0x3f, 0x42, 0x40, 0x44, 0x3b, 0x40, 0x3e, 0x49, 0x3a, + 0x3c, 0x35, 0x30, 0x3e, 0x3e, 0x3d, 0x36, 0x3a, 0x3e, 0x3a, 0x4a, 0x3e, + 0x3d, 0x49, 0x40, 0x43, 0x3e, 0x45, 0x3f, 0x3c, 0x3b, 0x42, 0x3a, 0x39, + 0x3b, 
0x47, 0x3f, 0x39, 0x49, 0x46, 0x3d, 0x34, 0x32, 0x44, 0x46, 0x42, + 0x47, 0x39, 0x49, 0x48, 0x3b, 0x38, 0x45, 0x45, 0x37, 0x38, 0x46, 0x46, + 0x37, 0x42, 0x35, 0x34, 0x45, 0x42, 0x35, 0x43, 0x3b, 0x3a, 0x43, 0x43, + 0x40, 0x42, 0x35, 0x3f, 0x38, 0x3f, 0x3a, 0x3a, 0x3b, 0x3f, 0x3e, 0x36, + 0x3f, 0x3c, 0x48, 0x3b, 0x3a, 0x41, 0x41, 0x35, 0x33, 0x3f, 0x3b, 0x45, + 0x48, 0x36, 0x40, 0x38, 0x47, 0x3d, 0x35, 0x40, 0x41, 0x42, 0x41, 0x37, + 0x41, 0x3e, 0x36, 0x48, 0x3e, 0x3c, 0x32, 0x39, 0x41, 0x40, 0x38, 0x3f, + 0x46, 0x43, 0x33, 0x40, 0x43, 0x43, 0x3a, 0x49, 0x3f, 0x35, 0x2c, 0x5d, + 0x43, 0x49, 0x52, 0x3b, 0x3c, 0x41, 0x40, 0x4a, 0x33, 0x50, 0x41, 0x46, + 0x52, 0x41, 0x68, 0x48, 0x44, 0x53, 0x54, 0x55, 0x42, 0x42, 0x57, 0x44, + 0x47, 0x35, 0x35, 0x3e, 0x4b, 0x44, 0x4e, 0x38, 0x55, 0x2f, 0x36, 0x2d, + 0x40, 0x48, 0x4b, 0x41, 0x48, 0x36, 0x32, 0x32, 0x44, 0x42, 0x47, 0x42, + 0x48, 0x3d, 0x3d, 0x39, 0x3e, 0x35, 0x4b, 0x39, 0x38, 0x3a, 0x39, 0x46, + 0x38, 0x3f, 0x3a, 0x42, 0x4b, 0x45, 0x3e, 0x32, 0x46, 0x43, 0x3b, 0x40, + 0x45, 0x41, 0x3e, 0x43, 0x37, 0x3d, 0x43, 0x3b, 0x46, 0x48, 0x42, 0x3b, + 0x3d, 0x48, 0x4a, 0x3c, 0x3b, 0x42, 0x40, 0x3c, 0x3a, 0x42, 0x38, 0x47, + 0x3b, 0x3b, 0x3d, 0x41, 0x3f, 0x38, 0x3f, 0x4a, 0x44, 0x3f, 0x47, 0x3a, + 0x47, 0x44, 0x43, 0x43, 0x34, 0x3d, 0x3a, 0x3c, 0x47, 0x3f, 0x3e, 0x39, + 0x42, 0x4a, 0x40, 0x36, 0x40, 0x41, 0x42, 0x3f, 0x3f, 0x43, 0x39, 0x38, + 0x3c, 0x3b, 0x4c, 0x2f, 0x41, 0x39, 0x40, 0x42, 0x3f, 0x42, 0x40, 0x36, + 0x3b, 0x45, 0x41, 0x41, 0x44, 0x45, 0x42, 0x37, 0x3d, 0x3a, 0x33, 0x3e, + 0x3b, 0x3b, 0x3c, 0x3d, 0x38, 0x49, 0x44, 0x39, 0x3f, 0x48, 0x3d, 0x41, + 0x42, 0x43, 0x44, 0x3e, 0x41, 0x3d, 0x32, 0x59, 0x45, 0x4b, 0x4b, 0x38, + 0x37, 0x3d, 0x48, 0x42, 0x3d, 0x52, 0x43, 0x46, 0x54, 0x48, 0x67, 0x4d, + 0x45, 0x4e, 0x49, 0x52, 0x45, 0x45, 0x58, 0x3b, 0x41, 0x38, 0x3f, 0x3f, + 0x49, 0x44, 0x4f, 0x48, 0x57, 0x31, 0x3c, 0x2a, 0x3e, 0x4c, 0x41, 0x40, + 0x47, 0x3f, 0x33, 0x34, 0x3f, 0x42, 0x48, 0x43, 0x4b, 0x38, 0x39, 0x3d, + 0x3f, 
0x3e, 0x4b, 0x3f, 0x35, 0x36, 0x3c, 0x46, 0x3c, 0x45, 0x37, 0x3b, + 0x3c, 0x39, 0x41, 0x40, 0x41, 0x43, 0x44, 0x41, 0x45, 0x4f, 0x44, 0x43, + 0x44, 0x3c, 0x45, 0x34, 0x42, 0x45, 0x3f, 0x46, 0x3f, 0x43, 0x3d, 0x3a, + 0x39, 0x47, 0x45, 0x3d, 0x3f, 0x3b, 0x3d, 0x42, 0x38, 0x48, 0x48, 0x3b, + 0x3c, 0x3a, 0x3f, 0x41, 0x44, 0x4b, 0x44, 0x48, 0x41, 0x3c, 0x3d, 0x3c, + 0x3e, 0x3a, 0x4a, 0x3b, 0x49, 0x35, 0x3a, 0x3d, 0x41, 0x3f, 0x49, 0x39, + 0x44, 0x37, 0x3f, 0x3c, 0x42, 0x40, 0x4a, 0x46, 0x39, 0x38, 0x46, 0x37, + 0x41, 0x46, 0x41, 0x45, 0x40, 0x3b, 0x3b, 0x33, 0x3b, 0x39, 0x3c, 0x43, + 0x37, 0x3c, 0x44, 0x3d, 0x46, 0x39, 0x3c, 0x3c, 0x44, 0x48, 0x41, 0x44, + 0x41, 0x43, 0x46, 0x3b, 0x47, 0x41, 0x31, 0x41, 0x44, 0x40, 0x43, 0x42, + 0x3e, 0x43, 0x34, 0x65, 0x4f, 0x50, 0x4d, 0x3a, 0x37, 0x43, 0x4d, 0x4a, + 0x3d, 0x54, 0x40, 0x42, 0x5b, 0x3b, 0x71, 0x49, 0x44, 0x4f, 0x54, 0x56, + 0x48, 0x40, 0x52, 0x41, 0x42, 0x38, 0x3c, 0x49, 0x4a, 0x45, 0x51, 0x35, + 0x54, 0x2f, 0x35, 0x25, 0x4d, 0x3f, 0x4d, 0x43, 0x49, 0x33, 0x32, 0x3a, + 0x46, 0x48, 0x48, 0x3d, 0x43, 0x3a, 0x3c, 0x3a, 0x48, 0x40, 0x4b, 0x3b, + 0x45, 0x3b, 0x3f, 0x38, 0x37, 0x41, 0x31, 0x3b, 0x41, 0x43, 0x43, 0x37, + 0x48, 0x3f, 0x48, 0x37, 0x40, 0x4a, 0x43, 0x45, 0x3d, 0x39, 0x37, 0x37, + 0x3c, 0x3f, 0x47, 0x48, 0x43, 0x3e, 0x41, 0x3f, 0x3e, 0x38, 0x3e, 0x37, + 0x45, 0x45, 0x35, 0x44, 0x38, 0x3a, 0x49, 0x43, 0x40, 0x41, 0x40, 0x44, + 0x3c, 0x3e, 0x40, 0x38, 0x42, 0x41, 0x3c, 0x41, 0x3a, 0x3b, 0x3c, 0x3a, + 0x49, 0x3c, 0x42, 0x44, 0x3f, 0x39, 0x45, 0x32, 0x45, 0x43, 0x45, 0x39, + 0x43, 0x41, 0x4b, 0x39, 0x32, 0x3c, 0x3c, 0x36, 0x39, 0x3f, 0x46, 0x32, + 0x39, 0x35, 0x4f, 0x32, 0x3e, 0x40, 0x3d, 0x3e, 0x3a, 0x39, 0x4c, 0x38, + 0x43, 0x38, 0x49, 0x3b, 0x33, 0x39, 0x3b, 0x36, 0x36, 0x43, 0x3b, 0x3c, + 0x32, 0x3c, 0x3a, 0x45, 0x31, 0x3d, 0x37, 0x40, 0x3f, 0x3f, 0x35, 0xff, + 0x49, 0x4e, 0x4c, 0x3c, 0x36, 0x43, 0x46, 0x45, 0x41, 0x59, 0x44, 0x4a, + 0x53, 0x44, 0x71, 0x4a, 0x39, 0x4f, 0x50, 0x4b, 0x47, 0x42, 0x5a, 0x3c, + 0x45, 
0x38, 0x3e, 0x42, 0x53, 0x43, 0x52, 0x3a, 0x52, 0x34, 0x31, 0x20, + 0x49, 0x4e, 0x46, 0x43, 0x4b, 0x3d, 0x2b, 0x27, 0x46, 0x46, 0x47, 0x41, + 0x42, 0x37, 0x39, 0x38, 0x45, 0x3f, 0x51, 0x3d, 0x48, 0x3f, 0x33, 0x3f, + 0x38, 0x45, 0x31, 0x38, 0x41, 0x3d, 0x47, 0x39, 0x42, 0x40, 0x4c, 0x3f, + 0x40, 0x42, 0x41, 0x41, 0x41, 0x42, 0x39, 0x35, 0x3f, 0x46, 0x45, 0x36, + 0x3f, 0x43, 0x3b, 0x39, 0x41, 0x38, 0x43, 0x37, 0x3d, 0x44, 0x3b, 0x40, + 0x36, 0x3d, 0x42, 0x41, 0x41, 0x3d, 0x38, 0x4a, 0x40, 0x4a, 0x4c, 0x38, + 0x3f, 0x40, 0x45, 0x3c, 0x3f, 0x4b, 0x43, 0x41, 0x43, 0x3e, 0x43, 0x3f, + 0x36, 0x40, 0x40, 0x39, 0x3f, 0x3a, 0x3a, 0x30, 0x41, 0x3c, 0x3c, 0x34, + 0x46, 0x38, 0x43, 0x34, 0x3a, 0x42, 0x43, 0x42, 0x40, 0x41, 0x49, 0x34, + 0x35, 0x40, 0x47, 0x3d, 0x3d, 0x3e, 0x4c, 0x33, 0x3c, 0x3b, 0x39, 0x43, + 0x3a, 0x3e, 0x3b, 0x37, 0x3f, 0x42, 0x31, 0x3d, 0x41, 0x3e, 0x32, 0x47, + 0x34, 0x41, 0x3d, 0x35, 0x39, 0x40, 0x38, 0x69, 0x4f, 0x4a, 0x49, 0x37, + 0x37, 0x44, 0x43, 0x46, 0x40, 0x58, 0x43, 0x48, 0x54, 0x46, 0x6c, 0x50, + 0x3a, 0x50, 0x50, 0x57, 0x47, 0x46, 0x5c, 0x40, 0x40, 0x39, 0x3e, 0x46, + 0x53, 0x46, 0x5c, 0x36, 0x4f, 0x32, 0x30, 0x2d, 0x4a, 0x48, 0x41, 0x45, + 0x47, 0x2f, 0x32, 0x2b, 0x43, 0x40, 0x43, 0x3c, 0x40, 0x44, 0x3e, 0x37, + 0x39, 0x3e, 0x48, 0x42, 0x45, 0x36, 0x47, 0x3f, 0x3b, 0x41, 0x35, 0x35, + 0x3b, 0x3e, 0x35, 0x43, 0x3e, 0x41, 0x3d, 0x36, 0x41, 0x3c, 0x40, 0x44, + 0x3d, 0x40, 0x35, 0x32, 0x48, 0x3e, 0x39, 0x42, 0x44, 0x3d, 0x39, 0x3b, + 0x3b, 0x45, 0x40, 0x4a, 0x3f, 0x41, 0x43, 0x39, 0x42, 0x44, 0x4c, 0x3c, + 0x3f, 0x3e, 0x3f, 0x43, 0x40, 0x42, 0x4c, 0x3b, 0x3e, 0x3d, 0x49, 0x42, + 0x40, 0x44, 0x40, 0x34, 0x36, 0x40, 0x45, 0x39, 0x42, 0x40, 0x3e, 0x44, + 0x45, 0x37, 0x3c, 0x38, 0x3e, 0x49, 0x3e, 0x3c, 0x41, 0x3d, 0x42, 0x32, + 0x40, 0x45, 0x3e, 0x36, 0x44, 0x3a, 0x4e, 0x38, 0x43, 0x38, 0x40, 0x38, + 0x49, 0x42, 0x40, 0x3d, 0x42, 0x48, 0x48, 0x3d, 0x41, 0x3a, 0x3f, 0x41, + 0x38, 0x3c, 0x44, 0x39, 0x3a, 0x32, 0x3a, 0x3e, 0x3d, 0x3b, 0x39, 0x38, + 0x3a, 
0x43, 0x3a, 0x6b, 0x45, 0x50, 0x47, 0x33, 0x38, 0x48, 0x4d, 0x4f, + 0x39, 0x4b, 0x46, 0x4a, 0x4f, 0x42, 0x6f, 0x4b, 0x40, 0x55, 0x54, 0x50, + 0x42, 0x47, 0x5e, 0x46, 0x40, 0x34, 0x40, 0x47, 0x52, 0x46, 0x55, 0x3b, + 0x4f, 0x2b, 0x35, 0x33, 0x4c, 0x44, 0x44, 0x48, 0x47, 0x37, 0x35, 0x27, + 0x4a, 0x3b, 0x41, 0x40, 0x40, 0x3e, 0x36, 0x39, 0x3e, 0x3c, 0x45, 0x3f, + 0x4d, 0x41, 0x3d, 0x48, 0x47, 0x46, 0x33, 0x3d, 0x3d, 0x3e, 0x34, 0x3f, + 0x3e, 0x3a, 0x41, 0x35, 0x3b, 0x3e, 0x42, 0x3c, 0x42, 0x42, 0x40, 0x31, + 0x37, 0x40, 0x36, 0x42, 0x48, 0x39, 0x3d, 0x3c, 0x3a, 0x43, 0x39, 0x3d, + 0x47, 0x49, 0x43, 0x3d, 0x45, 0x39, 0x44, 0x37, 0x3e, 0x4d, 0x3d, 0x40, + 0x3d, 0x4c, 0x4d, 0x44, 0x3c, 0x3d, 0x46, 0x41, 0x41, 0x42, 0x40, 0x40, + 0x41, 0x3a, 0x3c, 0x3b, 0x3c, 0x44, 0x40, 0x34, 0x44, 0x38, 0x3b, 0x33, + 0x45, 0x45, 0x44, 0x3f, 0x3e, 0x3a, 0x3b, 0x3b, 0x43, 0x39, 0x3a, 0x45, + 0x3b, 0x3a, 0x4b, 0x39, 0x3d, 0x38, 0x41, 0x39, 0x42, 0x45, 0x43, 0x40, + 0x3e, 0x35, 0x44, 0x3f, 0x45, 0x41, 0x40, 0x3e, 0x43, 0x42, 0x37, 0x3a, + 0x38, 0x35, 0x3a, 0x48, 0x3e, 0x3b, 0x40, 0x38, 0x3c, 0x3c, 0x3b, 0x6a, + 0x48, 0x4d, 0x4d, 0x34, 0x38, 0x40, 0x4a, 0x45, 0x3c, 0x4f, 0x41, 0x4b, + 0x58, 0x46, 0x71, 0x49, 0x3d, 0x53, 0x44, 0x52, 0x42, 0x3e, 0x57, 0x4c, + 0x4c, 0x38, 0x40, 0x3b, 0x5c, 0x4c, 0x52, 0x3e, 0x4c, 0x2d, 0x32, 0x37, + 0x49, 0x3f, 0x41, 0x47, 0x4a, 0x3b, 0x2f, 0x26, 0x45, 0x40, 0x47, 0x42, + 0x3d, 0x39, 0x2d, 0x2c, 0x3f, 0x45, 0x46, 0x44, 0x48, 0x43, 0x42, 0x48, + 0x40, 0x41, 0x3b, 0x3b, 0x41, 0x3b, 0x39, 0x40, 0x3b, 0x47, 0x3f, 0x38, + 0x3f, 0x49, 0x3b, 0x35, 0x40, 0x45, 0x38, 0x35, 0x36, 0x34, 0x3e, 0x3d, + 0x46, 0x3e, 0x33, 0x38, 0x43, 0x48, 0x3f, 0x45, 0x31, 0x44, 0x38, 0x35, + 0x3c, 0x41, 0x4b, 0x44, 0x3d, 0x43, 0x38, 0x48, 0x3c, 0x39, 0x4a, 0x42, + 0x3d, 0x43, 0x3f, 0x49, 0x3e, 0x47, 0x49, 0x41, 0x3b, 0x3c, 0x47, 0x3a, + 0x3d, 0x40, 0x4a, 0x38, 0x3d, 0x3b, 0x47, 0x3a, 0x36, 0x47, 0x42, 0x46, + 0x3c, 0x3d, 0x45, 0x3b, 0x48, 0x3f, 0x38, 0x36, 0x39, 0x46, 0x43, 0x3a, + 0x41, 
0x3d, 0x39, 0x39, 0x46, 0x37, 0x3f, 0x3f, 0x3a, 0x46, 0x3f, 0x39, + 0x49, 0x44, 0x42, 0x3a, 0x3a, 0x43, 0x3e, 0x42, 0x3d, 0x3d, 0x43, 0x40, + 0x43, 0x3c, 0x3f, 0x43, 0x40, 0x42, 0x3b, 0x57, 0x4a, 0x4f, 0x4a, 0x2d, + 0x3b, 0x48, 0x45, 0x42, 0x34, 0x4c, 0x3e, 0x4f, 0x4d, 0x40, 0x6c, 0x4b, + 0x3b, 0x4d, 0x4c, 0x57, 0x49, 0x3d, 0x5d, 0x44, 0x43, 0x29, 0x42, 0x3f, + 0x5b, 0x47, 0x4f, 0x3e, 0x54, 0x2e, 0x34, 0x34, 0x4b, 0x47, 0x46, 0x46, + 0x4b, 0x34, 0x36, 0x28, 0x3e, 0x3f, 0x42, 0x40, 0x3b, 0x38, 0x39, 0x42, + 0x49, 0x3d, 0x49, 0x47, 0x47, 0x3b, 0x43, 0x34, 0x39, 0x36, 0x42, 0x3d, + 0x37, 0x40, 0x37, 0x38, 0x46, 0x42, 0x49, 0x37, 0x44, 0x3f, 0x38, 0x3e, + 0x36, 0x32, 0x33, 0x38, 0x40, 0x46, 0x42, 0x34, 0x41, 0x42, 0x3e, 0x38, + 0x44, 0x3e, 0x3f, 0x43, 0x3f, 0x43, 0x35, 0x3f, 0x4d, 0x3b, 0x43, 0x39, + 0x40, 0x47, 0x3f, 0x4a, 0x3a, 0x3f, 0x45, 0x45, 0x48, 0x42, 0x3b, 0x47, + 0x42, 0x4b, 0x47, 0x3e, 0x3c, 0x42, 0x46, 0x39, 0x41, 0x3f, 0x48, 0x33, + 0x45, 0x34, 0x3d, 0x30, 0x40, 0x4c, 0x40, 0x40, 0x39, 0x37, 0x40, 0x33, + 0x49, 0x42, 0x45, 0x38, 0x3c, 0x43, 0x45, 0x35, 0x37, 0x33, 0x34, 0x3b, + 0x3b, 0x38, 0x39, 0x41, 0x42, 0x40, 0x3e, 0x3e, 0x41, 0x33, 0x3a, 0x36, + 0x40, 0x3a, 0x3c, 0x45, 0x43, 0x3c, 0x40, 0x41, 0x49, 0x47, 0x35, 0x34, + 0x3a, 0x3d, 0x3a, 0x68, 0x4f, 0x48, 0x43, 0x36, 0x37, 0x3e, 0x45, 0x49, + 0x3a, 0x4d, 0x41, 0x3d, 0x46, 0x45, 0x65, 0x46, 0x38, 0x4d, 0x4a, 0x53, + 0x43, 0x41, 0x5d, 0x47, 0x41, 0x34, 0x39, 0x43, 0x4e, 0x48, 0x50, 0x38, + 0x53, 0x32, 0x30, 0x2e, 0x49, 0x4c, 0x4d, 0x3f, 0x46, 0x38, 0x34, 0x2b, + 0x44, 0x44, 0x41, 0x41, 0x36, 0x40, 0x3f, 0x32, 0x46, 0x38, 0x50, 0x45, + 0x3f, 0x3d, 0x3b, 0x36, 0x3b, 0x43, 0x3a, 0x34, 0x36, 0x3f, 0x39, 0x35, + 0x3c, 0x40, 0x40, 0x37, 0x3c, 0x39, 0x3d, 0x36, 0x48, 0x3d, 0x43, 0x34, + 0x3b, 0x46, 0x43, 0x41, 0x33, 0x3e, 0x44, 0x3d, 0x44, 0x44, 0x4c, 0x3c, + 0x37, 0x49, 0x42, 0x35, 0x45, 0x3a, 0x3c, 0x41, 0x3a, 0x45, 0x46, 0x41, + 0x3c, 0x48, 0x46, 0x36, 0x36, 0x42, 0x3b, 0x46, 0x42, 0x45, 0x44, 0x47, + 0x3f, 
0x44, 0x3a, 0x35, 0x37, 0x46, 0x40, 0x38, 0x40, 0x3d, 0x36, 0x2c, + 0x34, 0x47, 0x40, 0x38, 0x3f, 0x3f, 0x44, 0x2d, 0x3b, 0x3d, 0x3e, 0x44, + 0x3c, 0x40, 0x3e, 0x33, 0x3c, 0x3a, 0x49, 0x40, 0x42, 0x42, 0x3a, 0x3b, + 0x33, 0x3d, 0x3c, 0x43, 0x3e, 0x3d, 0x3a, 0x3a, 0x48, 0x3e, 0x3c, 0x39, + 0x3f, 0x44, 0x37, 0x40, 0x3f, 0x3c, 0x3e, 0x3d, 0x38, 0x42, 0x34, 0x62, + 0x51, 0x47, 0x44, 0x3f, 0x32, 0x3c, 0x3f, 0x46, 0x3d, 0x46, 0x3e, 0x45, + 0x4a, 0x3e, 0x5d, 0x43, 0x45, 0x49, 0x4a, 0x55, 0x41, 0x3c, 0x5a, 0x44, + 0x43, 0x3b, 0x3c, 0x3a, 0x4b, 0x4e, 0x4d, 0x42, 0x49, 0x30, 0x3b, 0x38, + 0x42, 0x44, 0x51, 0x40, 0x48, 0x33, 0x3f, 0x2b, 0x3c, 0x41, 0x3c, 0x45, + 0x35, 0x39, 0x42, 0x37, 0x40, 0x46, 0x46, 0x3f, 0x41, 0x45, 0x42, 0x3d, + 0x43, 0x38, 0x3e, 0x38, 0x3c, 0x39, 0x40, 0x38, 0x37, 0x36, 0x3d, 0x3d, + 0x38, 0x47, 0x45, 0x3b, 0x45, 0x44, 0x42, 0x2e, 0x37, 0x40, 0x42, 0x42, + 0x3c, 0x36, 0x3b, 0x39, 0x44, 0x4d, 0x42, 0x3f, 0x3a, 0x3e, 0x45, 0x34, + 0x3c, 0x43, 0x47, 0x43, 0x3f, 0x48, 0x3b, 0x44, 0x3d, 0x44, 0x43, 0x3e, + 0x40, 0x4a, 0x31, 0x42, 0x42, 0x43, 0x48, 0x45, 0x3a, 0x42, 0x36, 0x2f, + 0x3c, 0x3e, 0x3b, 0x3b, 0x44, 0x3f, 0x3a, 0x2c, 0x47, 0x3f, 0x4a, 0x40, + 0x40, 0x40, 0x3c, 0x2a, 0x3e, 0x44, 0x40, 0x43, 0x3a, 0x42, 0x39, 0x34, + 0x49, 0x3e, 0x36, 0x42, 0x3f, 0x42, 0x33, 0x3b, 0x3c, 0x45, 0x39, 0x3f, + 0x3e, 0x3f, 0x41, 0x3d, 0x32, 0x3b, 0x31, 0x40, 0x3f, 0x44, 0x3c, 0x3f, + 0x40, 0x46, 0x45, 0x36, 0x36, 0x42, 0x30, 0x57, 0x47, 0x44, 0x48, 0x3f, + 0x35, 0x37, 0x3f, 0x3f, 0x38, 0x4a, 0x41, 0x46, 0x50, 0x3d, 0x5b, 0x41, + 0x3e, 0x3c, 0x4a, 0x54, 0x45, 0x41, 0x5b, 0x46, 0x3d, 0x3b, 0x43, 0x33, + 0x45, 0x4e, 0x43, 0x3b, 0x44, 0x37, 0x37, 0x32, 0x4c, 0x3d, 0x4c, 0x3f, + 0x49, 0x3b, 0x37, 0x3a, 0x33, 0x43, 0x3f, 0x40, 0x44, 0x36, 0x3b, 0x44, + 0x45, 0x40, 0x3c, 0x3c, 0x41, 0x44, 0x3b, 0x3d, 0x33, 0x37, 0x3c, 0x35, + 0x3d, 0x3f, 0x39, 0x38, 0x33, 0x43, 0x3e, 0x39, 0x3b, 0x3e, 0x41, 0x35, + 0x40, 0x46, 0x43, 0x35, 0x41, 0x3d, 0x32, 0x39, 0x3c, 0x40, 0x3e, 0x3f, + 0x42, 
0x38, 0x3b, 0x45, 0x3a, 0x3d, 0x40, 0x36, 0x3a, 0x40, 0x46, 0x44, + 0x48, 0x45, 0x3f, 0x3a, 0x45, 0x45, 0x3c, 0x3b, 0x40, 0x4c, 0x39, 0x3a, + 0x38, 0x39, 0x46, 0x3a, 0x3e, 0x4b, 0x34, 0x39, 0x3d, 0x3f, 0x40, 0x39, + 0x45, 0x31, 0x45, 0x29, 0x3f, 0x38, 0x3a, 0x3f, 0x38, 0x3b, 0x36, 0x2d, + 0x43, 0x3d, 0x45, 0x3c, 0x46, 0x3f, 0x40, 0x3c, 0x3a, 0x3e, 0x3d, 0x38, + 0x3f, 0x3c, 0x3f, 0x42, 0x35, 0x3f, 0x3a, 0x43, 0x3d, 0x43, 0x3d, 0x33, + 0x3d, 0x48, 0x42, 0x3d, 0x45, 0x46, 0x3d, 0x35, 0x32, 0x44, 0x42, 0x37, + 0x3d, 0x40, 0x3c, 0x47, 0x4a, 0x45, 0x47, 0x2f, 0x33, 0x36, 0x3f, 0x42, + 0x38, 0x43, 0x3e, 0x3a, 0x41, 0x3f, 0x5f, 0x3f, 0x48, 0x3a, 0x44, 0x47, + 0x41, 0x3e, 0x57, 0x42, 0x41, 0x33, 0x34, 0x39, 0x42, 0x44, 0x42, 0x3c, + 0x49, 0x34, 0x37, 0x33, 0x47, 0x38, 0x43, 0x3d, 0x43, 0x3e, 0x3e, 0x36, + 0x41, 0x41, 0x37, 0x40, 0x39, 0x3e, 0x3b, 0x3b, 0x3e, 0x41, 0x3d, 0x3b, + 0x43, 0x3e, 0x39, 0x43, 0x2f, 0x3e, 0x33, 0x40, 0x45, 0x47, 0x30, 0x46, + 0x3f, 0x3f, 0x37, 0x42, 0x3d, 0x42, 0x43, 0x37, 0x38, 0x3c, 0x35, 0x34, + 0x41, 0x43, 0x3e, 0x3e, 0x3f, 0x49, 0x35, 0x35, 0x38, 0x36, 0x3a, 0x43, + 0x38, 0x46, 0x48, 0x36, 0x3f, 0x39, 0x3b, 0x3e, 0x48, 0x47, 0x41, 0x34, + 0x3b, 0x3c, 0x37, 0x3e, 0x40, 0x41, 0x3b, 0x3d, 0x43, 0x42, 0x3a, 0x39, + 0x3b, 0x43, 0x38, 0x2b, 0x43, 0x41, 0x48, 0x35, 0x44, 0x44, 0x3e, 0x2c, + 0x46, 0x40, 0x3e, 0x41, 0x38, 0x34, 0x35, 0x37, 0x34, 0x3f, 0x3d, 0x46, + 0x33, 0x3c, 0x3c, 0x2e, 0x3b, 0x45, 0x3d, 0x3e, 0x3a, 0x42, 0x3c, 0x36, + 0x3a, 0x42, 0x39, 0x43, 0x35, 0x39, 0x40, 0x44, 0x47, 0x41, 0x44, 0x3d, + 0x41, 0x3e, 0x38, 0x39, 0x45, 0x3a, 0x35, 0x43, 0x3f, 0x44, 0x41, 0x49, + 0x47, 0x3f, 0x44, 0x40, 0x38, 0x43, 0x40, 0x3e, 0x39, 0x42, 0x32, 0x3b, + 0x42, 0x47, 0x57, 0x37, 0x36, 0x38, 0x43, 0x49, 0x3b, 0x34, 0x54, 0x42, + 0x3d, 0x3f, 0x3e, 0x3b, 0x38, 0x41, 0x43, 0x3a, 0x44, 0x39, 0x34, 0x2c, + 0x38, 0x43, 0x4b, 0x3f, 0x40, 0x3e, 0x32, 0x33, 0x3d, 0x44, 0x45, 0x44, + 0x3e, 0x35, 0x37, 0x39, 0x40, 0x3e, 0x40, 0x3c, 0x34, 0x43, 0x37, 0x40, + 0x39, 
0x3e, 0x3d, 0x43, 0x3a, 0x44, 0x43, 0x44, 0x3d, 0x3b, 0x45, 0x3b, + 0x3a, 0x3a, 0x3f, 0x37, 0x43, 0x3b, 0x33, 0x35, 0x40, 0x47, 0x3e, 0x3c, + 0x39, 0x3c, 0x34, 0x29, 0x3c, 0x3e, 0x46, 0x3e, 0x3c, 0x38, 0x3f, 0x2d, + 0x3d, 0x3d, 0x3f, 0x3f, 0x3d, 0x45, 0x3b, 0x32, 0x39, 0x3f, 0x41, 0x38, + 0x36, 0x3e, 0x3a, 0x35, 0x40, 0x3f, 0x3b, 0x32, 0x3c, 0x39, 0x3e, 0x35, + 0x3e, 0x45, 0x34, 0x38, 0x44, 0x39, 0x3f, 0x31, 0x34, 0x39, 0x3f, 0x38, + 0x44, 0x42, 0x3f, 0x3b, 0x39, 0x3d, 0x39, 0x3b, 0x44, 0x46, 0x38, 0x3d, + 0x45, 0x37, 0x40, 0x3a, 0x3a, 0x39, 0x35, 0x3c, 0x39, 0x40, 0x47, 0x3e, + 0x38, 0x42, 0x41, 0x3b, 0x48, 0x3f, 0x3a, 0x3e, 0x3d, 0x3f, 0x32, 0x3b, + 0x3f, 0x3d, 0x3e, 0x44, 0x43, 0x41, 0x44, 0x47, 0x48, 0x41, 0x41, 0x36, + 0x3a, 0x33, 0x3c, 0x3c, 0x37, 0x3e, 0x40, 0x34, 0x3f, 0x42, 0x53, 0x40, + 0x3f, 0x35, 0x3e, 0x46, 0x3a, 0x3e, 0x4b, 0x41, 0x46, 0x32, 0x39, 0x36, + 0x3b, 0x4f, 0x36, 0x3c, 0x40, 0x3a, 0x40, 0x40, 0x47, 0x3e, 0x49, 0x37, + 0x3f, 0x31, 0x3e, 0x40, 0x3b, 0x3f, 0x43, 0x44, 0x3a, 0x3d, 0x31, 0x41, + 0x41, 0x33, 0x43, 0x40, 0x3c, 0x3a, 0x41, 0x40, 0x37, 0x3f, 0x34, 0x3e, + 0x44, 0x42, 0x3d, 0x3f, 0x3f, 0x34, 0x36, 0x34, 0x31, 0x41, 0x32, 0x39, + 0x3e, 0x3d, 0x42, 0x35, 0x3e, 0x3a, 0x41, 0x47, 0x3d, 0x42, 0x33, 0x32, + 0x43, 0x42, 0x36, 0x41, 0x3e, 0x39, 0x46, 0x39, 0x35, 0x3d, 0x3d, 0x40, + 0x38, 0x44, 0x3d, 0x31, 0x44, 0x39, 0x3a, 0x45, 0x42, 0x41, 0x3d, 0x36, + 0x3f, 0x3c, 0x39, 0x3d, 0x32, 0x39, 0x42, 0x34, 0x3f, 0x38, 0x44, 0x3c, + 0x43, 0x45, 0x41, 0x2d, 0x44, 0x42, 0x3d, 0x3f, 0x44, 0x38, 0x3d, 0x35, + 0x3a, 0x48, 0x40, 0x3b, 0x3d, 0x36, 0x3b, 0x40, 0x3f, 0x3a, 0x3a, 0x3f, + 0x3c, 0x33, 0x39, 0x3c, 0x3c, 0x38, 0x47, 0x36, 0x3d, 0x41, 0x46, 0x41, + 0x34, 0x46, 0x48, 0x46, 0x3d, 0x3c, 0x40, 0x43, 0x3d, 0x41, 0x37, 0x3e, + 0x39, 0x47, 0x3f, 0x39, 0x46, 0x43, 0x3f, 0x41, 0x45, 0x37, 0x40, 0x3a, + 0x3d, 0x44, 0x3f, 0x3b, 0x3b, 0x40, 0x4f, 0x3d, 0x3d, 0x41, 0x3c, 0x43, + 0x3e, 0x46, 0x4e, 0x40, 0x3f, 0x34, 0x48, 0x29, 0x45, 0x44, 0x46, 0x41, + 0x45, 
0x32, 0x3e, 0x38, 0x39, 0x3a, 0x3e, 0x3e, 0x4c, 0x34, 0x3c, 0x40, + 0x4a, 0x44, 0x3d, 0x46, 0x3b, 0x3e, 0x42, 0x42, 0x3a, 0x41, 0x43, 0x41, + 0x39, 0x3f, 0x3e, 0x3c, 0x36, 0x48, 0x3f, 0x3e, 0x3e, 0x37, 0x3f, 0x3f, + 0x3b, 0x40, 0x3e, 0x35, 0x32, 0x35, 0x3f, 0x33, 0x3f, 0x38, 0x43, 0x37, + 0x49, 0x38, 0x37, 0x3c, 0x3c, 0x40, 0x40, 0x3a, 0x3a, 0x46, 0x37, 0x34, + 0x34, 0x3b, 0x3d, 0x2f, 0x3a, 0x38, 0x3d, 0x46, 0x3d, 0x3b, 0x3d, 0x38, + 0x35, 0x37, 0x44, 0x3c, 0x3d, 0x3e, 0x40, 0x3a, 0x40, 0x33, 0x3e, 0x38, + 0x40, 0x3e, 0x45, 0x37, 0x3f, 0x3b, 0x3c, 0x40, 0x3b, 0x3c, 0x3b, 0x33, + 0x41, 0x3f, 0x3b, 0x42, 0x31, 0x3b, 0x3a, 0x39, 0x3d, 0x41, 0x39, 0x40, + 0x43, 0x45, 0x39, 0x3b, 0x3a, 0x42, 0x43, 0x3d, 0x3f, 0x40, 0x47, 0x39, + 0x37, 0x3f, 0x47, 0x3f, 0x45, 0x41, 0x39, 0x3a, 0x41, 0x38, 0x3c, 0x3c, + 0x39, 0x40, 0x39, 0x3b, 0x3b, 0x3e, 0x38, 0x3b, 0x37, 0x48, 0x41, 0x3f, + 0x3e, 0x37, 0x3d, 0x44, 0x3c, 0x3e, 0x40, 0x39, 0x41, 0x42, 0x3d, 0x45, + 0x3b, 0x3e, 0x4c, 0x3b, 0x3a, 0x3a, 0x3e, 0x47, 0x3c, 0x3f, 0x48, 0x3f, + 0x46, 0x3f, 0x39, 0x25, 0x44, 0x3a, 0x3b, 0x40, 0x41, 0x39, 0x39, 0x47, + 0x3b, 0x32, 0x49, 0x42, 0x41, 0x3a, 0x43, 0x41, 0x3e, 0x35, 0x37, 0x3d, + 0x49, 0x40, 0x45, 0x3b, 0x3c, 0x38, 0x48, 0x3c, 0x3c, 0x35, 0x3f, 0x41, + 0x41, 0x4c, 0x36, 0x39, 0x37, 0x3d, 0x3b, 0x3e, 0x44, 0x32, 0x3d, 0x3f, + 0x3a, 0x3b, 0x3a, 0x47, 0x38, 0x42, 0x36, 0x34, 0x43, 0x3f, 0x3e, 0x40, + 0x34, 0x31, 0x36, 0x33, 0x42, 0x37, 0x41, 0x41, 0x40, 0x3d, 0x3d, 0x37, + 0x43, 0x3a, 0x3e, 0x44, 0x43, 0x3c, 0x35, 0x38, 0x38, 0x3c, 0x43, 0x36, + 0x3a, 0x38, 0x40, 0x3f, 0x3d, 0x3e, 0x37, 0x3b, 0x41, 0x3a, 0x3b, 0x3d, + 0x3c, 0x41, 0x3c, 0x41, 0x47, 0x3f, 0x3f, 0x3b, 0x3d, 0x3f, 0x3b, 0x45, + 0x38, 0x38, 0x40, 0x38, 0x46, 0x42, 0x39, 0x3d, 0x3d, 0x3b, 0x42, 0x36, + 0x42, 0x41, 0x3e, 0x3e, 0x36, 0x3f, 0x37, 0x3f, 0x36, 0x48, 0x3b, 0x39, + 0x3d, 0x3f, 0x43, 0x3e, 0x3c, 0x40, 0x48, 0x46, 0x43, 0x36, 0x42, 0x39, + 0x46, 0x3c, 0x37, 0x38, 0x49, 0x37, 0x36, 0x39, 0x3e, 0x42, 0x48, 0x3a, + 0x3c, 
0x3e, 0x42, 0x30, 0x3e, 0x34, 0x39, 0x3b, 0x46, 0x61, 0x46, 0x1e, + 0x4c, 0x3b, 0x40, 0x2d, 0x3c, 0x42, 0x32, 0x30, 0x49, 0x3e, 0x39, 0x34, + 0x30, 0x40, 0x31, 0x38, 0x40, 0x3d, 0x3c, 0x35, 0x3a, 0x36, 0x40, 0x3b, + 0x41, 0x40, 0x3b, 0x39, 0x37, 0x37, 0x3f, 0x3b, 0x3c, 0x3a, 0x40, 0x3a, + 0x36, 0x3c, 0x42, 0x39, 0x3e, 0x36, 0x40, 0x42, 0x39, 0x40, 0x3b, 0x34, + 0x37, 0x33, 0x36, 0x3f, 0x43, 0x33, 0x33, 0x27, 0x3d, 0x46, 0x40, 0x31, + 0x38, 0x3e, 0x41, 0x20, 0x3f, 0x39, 0x42, 0x35, 0x35, 0x45, 0x40, 0x1e, + 0x32, 0x35, 0x32, 0x3c, 0x35, 0x44, 0x46, 0x29, 0x3a, 0x3d, 0x37, 0x42, + 0x3b, 0x45, 0x3a, 0x26, 0x38, 0x40, 0x30, 0x37, 0x41, 0x40, 0x39, 0x2b, + 0x49, 0x3f, 0x43, 0x43, 0x40, 0x3a, 0x38, 0x29, 0x43, 0x3a, 0x37, 0x40, + 0x3f, 0x35, 0x3a, 0x28, 0x36, 0x3e, 0x3f, 0x43, 0x3c, 0x39, 0x42, 0x2c, + 0x38, 0x42, 0x38, 0x3d, 0x42, 0x38, 0x35, 0x2d, 0x34, 0x38, 0x3d, 0x43, + 0x46, 0x3e, 0x3c, 0x27, 0x3e, 0x40, 0x46, 0x39, 0x35, 0x3d, 0x42, 0x35, + 0x42, 0x36, 0x40, 0x3e, 0x3a, 0x3e, 0x3c, 0x37, 0x3a, 0x3c, 0x48, 0x48, + 0x48, 0x37, 0x3d, 0x38, 0x4b, 0x40, 0x43, 0x3b, 0x41, 0x46, 0x3c, 0x34, + 0x46, 0x3c, 0x3c, 0x3c, 0x4b, 0x64, 0x4a, 0x22, 0x52, 0x41, 0x42, 0x3b, + 0x42, 0x4a, 0x34, 0x37, 0x4b, 0x44, 0x3b, 0x4a, 0x38, 0x3f, 0x38, 0x3a, + 0x40, 0x41, 0x42, 0x3c, 0x33, 0x3e, 0x3c, 0x42, 0x2c, 0x4e, 0x47, 0x3f, + 0x38, 0x33, 0x39, 0x3f, 0x3b, 0x45, 0x37, 0x3a, 0x42, 0x42, 0x44, 0x3f, + 0x3c, 0x3c, 0x3e, 0x3d, 0x3c, 0x3c, 0x40, 0x2c, 0x3c, 0x3d, 0x42, 0x39, + 0x3a, 0x37, 0x43, 0x2a, 0x3d, 0x40, 0x41, 0x41, 0x46, 0x46, 0x42, 0x28, + 0x39, 0x3c, 0x37, 0x44, 0x46, 0x41, 0x47, 0x2b, 0x44, 0x33, 0x39, 0x3f, + 0x3f, 0x43, 0x3d, 0x23, 0x3a, 0x43, 0x41, 0x3b, 0x41, 0x42, 0x33, 0x1f, + 0x43, 0x3e, 0x3d, 0x40, 0x37, 0x33, 0x42, 0x28, 0x3b, 0x38, 0x37, 0x3c, + 0x34, 0x40, 0x44, 0x2a, 0x3c, 0x3a, 0x41, 0x37, 0x45, 0x3f, 0x3e, 0x26, + 0x41, 0x40, 0x35, 0x3d, 0x45, 0x3e, 0x3d, 0x29, 0x3c, 0x39, 0x3f, 0x3c, + 0x3d, 0x39, 0x38, 0x2d, 0x39, 0x38, 0x38, 0x44, 0x3c, 0x3e, 0x38, 0x26, + 0x40, 
0x36, 0x39, 0x38, 0x3f, 0x32, 0x39, 0x35, 0x3d, 0x3e, 0x35, 0x3a, + 0x3f, 0x3f, 0x31, 0x35, 0x34, 0x45, 0x3e, 0x43, 0x48, 0x3b, 0x37, 0x39, + 0x4d, 0x46, 0x54, 0x40, 0x41, 0x4e, 0x3d, 0x38, 0x4d, 0x38, 0x3a, 0x3b, + 0x49, 0x5a, 0x4a, 0x1e, 0x5e, 0x39, 0x38, 0x37, 0x3a, 0x51, 0x3a, 0x3c, + 0x50, 0x3f, 0x40, 0x42, 0x33, 0x3b, 0x2e, 0x4a, 0x3f, 0x4a, 0x3b, 0x43, + 0x36, 0x3e, 0x3d, 0x42, 0x39, 0x46, 0x4b, 0x3c, 0x3b, 0x3b, 0x35, 0x3e, + 0x3d, 0x4b, 0x3f, 0x41, 0x3f, 0x3b, 0x42, 0x42, 0x38, 0x3a, 0x41, 0x3d, + 0x36, 0x41, 0x37, 0x2f, 0x38, 0x37, 0x3f, 0x34, 0x35, 0x35, 0x45, 0x30, + 0x31, 0x42, 0x31, 0x3a, 0x3a, 0x3e, 0x3d, 0x23, 0x3f, 0x43, 0x3b, 0x41, + 0x35, 0x3b, 0x40, 0x25, 0x45, 0x3e, 0x42, 0x3b, 0x31, 0x40, 0x36, 0x28, + 0x43, 0x42, 0x30, 0x42, 0x32, 0x32, 0x36, 0x2c, 0x35, 0x3a, 0x3d, 0x3a, + 0x3c, 0x36, 0x3e, 0x30, 0x41, 0x42, 0x38, 0x41, 0x41, 0x3e, 0x3c, 0x23, + 0x37, 0x40, 0x3c, 0x3e, 0x3e, 0x3a, 0x37, 0x2b, 0x36, 0x40, 0x41, 0x42, + 0x3e, 0x38, 0x44, 0x22, 0x46, 0x38, 0x33, 0x3b, 0x3a, 0x3a, 0x3a, 0x24, + 0x36, 0x3b, 0x38, 0x44, 0x34, 0x38, 0x40, 0x28, 0x38, 0x3d, 0x36, 0x44, + 0x31, 0x3e, 0x37, 0x37, 0x36, 0x3f, 0x47, 0x38, 0x3b, 0x3e, 0x2c, 0x4c, + 0x36, 0x3c, 0x3b, 0x41, 0x4c, 0x3d, 0x3d, 0x40, 0x49, 0x44, 0x52, 0x3f, + 0x3b, 0x4d, 0x3c, 0x3a, 0x4f, 0x3b, 0x36, 0x3b, 0x4a, 0x5f, 0x4e, 0x1f, + 0x57, 0x3c, 0x3d, 0x3d, 0x46, 0x59, 0x42, 0x45, 0x52, 0x3d, 0x3a, 0x41, + 0x31, 0x39, 0x39, 0x4f, 0x43, 0x4e, 0x3e, 0x37, 0x3a, 0x37, 0x33, 0x47, + 0x32, 0x45, 0x47, 0x43, 0x31, 0x33, 0x38, 0x43, 0x3e, 0x47, 0x3d, 0x32, + 0x3b, 0x39, 0x3c, 0x42, 0x3d, 0x47, 0x42, 0x40, 0x3d, 0x3f, 0x3c, 0x34, + 0x3b, 0x3e, 0x42, 0x3d, 0x43, 0x35, 0x42, 0x2c, 0x35, 0x3d, 0x3c, 0x3d, + 0x3a, 0x3c, 0x46, 0x25, 0x43, 0x35, 0x3d, 0x39, 0x3a, 0x3c, 0x40, 0x2b, + 0x33, 0x40, 0x3d, 0x46, 0x45, 0x37, 0x3c, 0x36, 0x43, 0x37, 0x3e, 0x3a, + 0x3c, 0x47, 0x3f, 0x38, 0x36, 0x3e, 0x3a, 0x42, 0x3c, 0x42, 0x33, 0x39, + 0x3c, 0x3a, 0x3c, 0x40, 0x48, 0x3b, 0x40, 0x32, 0x37, 0x47, 0x34, 0x38, + 0x33, 
0x3d, 0x49, 0x2d, 0x36, 0x42, 0x3d, 0x3e, 0x47, 0x3c, 0x42, 0x2c, + 0x3b, 0x31, 0x3f, 0x3c, 0x3d, 0x3c, 0x3f, 0x2b, 0x41, 0x35, 0x33, 0x43, + 0x47, 0x39, 0x34, 0x2a, 0x3a, 0x3a, 0x40, 0x3d, 0x44, 0x3c, 0x39, 0x34, + 0x43, 0x40, 0x33, 0x3a, 0x3b, 0x42, 0x38, 0x3b, 0x34, 0x35, 0x40, 0x43, + 0x4b, 0x41, 0x3d, 0x38, 0x49, 0x44, 0x4d, 0x37, 0x3a, 0x4b, 0x40, 0x39, + 0x4e, 0x3b, 0x30, 0x38, 0x47, 0x5d, 0x50, 0x1f, 0x54, 0x35, 0x3a, 0x39, + 0x40, 0x4c, 0x46, 0x42, 0x52, 0x39, 0x39, 0x45, 0x41, 0x3c, 0x30, 0x5b, + 0x43, 0x4d, 0x4a, 0x3e, 0x31, 0x39, 0x41, 0x4c, 0x36, 0x44, 0x4c, 0x39, + 0x32, 0x41, 0x47, 0x3e, 0x34, 0x49, 0x45, 0x3b, 0x34, 0x3a, 0x3b, 0x47, + 0x43, 0x3e, 0x43, 0x32, 0x40, 0x3e, 0x3e, 0x38, 0x37, 0x3e, 0x37, 0x3a, + 0x3a, 0x40, 0x48, 0x2f, 0x3e, 0x3e, 0x46, 0x3a, 0x3e, 0x35, 0x49, 0x30, + 0x3a, 0x41, 0x3e, 0x39, 0x34, 0x45, 0x3d, 0x34, 0x48, 0x43, 0x43, 0x42, + 0x33, 0x39, 0x3b, 0x3f, 0x30, 0x46, 0x41, 0x39, 0x48, 0x3a, 0x3c, 0x3e, + 0x3f, 0x36, 0x40, 0x3d, 0x43, 0x40, 0x3e, 0x39, 0x44, 0x40, 0x44, 0x3b, + 0x43, 0x42, 0x39, 0x38, 0x3a, 0x3f, 0x3b, 0x3f, 0x38, 0x3d, 0x34, 0x30, + 0x34, 0x3d, 0x3f, 0x42, 0x44, 0x3e, 0x34, 0x32, 0x37, 0x46, 0x44, 0x38, + 0x3c, 0x45, 0x39, 0x2b, 0x41, 0x3c, 0x40, 0x40, 0x3a, 0x3a, 0x3c, 0x32, + 0x45, 0x42, 0x3d, 0x46, 0x38, 0x3b, 0x34, 0x35, 0x38, 0x43, 0x3d, 0x34, + 0x42, 0x3b, 0x38, 0x3d, 0x37, 0x43, 0x3f, 0x39, 0x4e, 0x39, 0x40, 0x3f, + 0x4d, 0x43, 0x49, 0x3f, 0x36, 0x41, 0x44, 0x39, 0x48, 0x3a, 0x35, 0x39, + 0x48, 0x59, 0x4e, 0x25, 0x58, 0x39, 0x42, 0x35, 0x43, 0x4e, 0x42, 0x3f, + 0x4a, 0x43, 0x3b, 0x3f, 0x3b, 0x37, 0x2b, 0x5a, 0x3d, 0x44, 0x3b, 0x40, + 0x31, 0x38, 0x37, 0x44, 0x32, 0x3e, 0x41, 0x3d, 0x2c, 0x42, 0x42, 0x3c, + 0x37, 0x45, 0x41, 0x41, 0x3d, 0x39, 0x41, 0x40, 0x3a, 0x46, 0x41, 0x40, + 0x40, 0x3d, 0x38, 0x31, 0x37, 0x3f, 0x42, 0x38, 0x3f, 0x3c, 0x48, 0x30, + 0x3e, 0x39, 0x3f, 0x3d, 0x3d, 0x44, 0x52, 0x35, 0x3b, 0x32, 0x42, 0x32, + 0x3a, 0x43, 0x39, 0x3b, 0x31, 0x43, 0x36, 0x3c, 0x3c, 0x3c, 0x41, 0x45, + 0x42, 
0x49, 0x41, 0x3b, 0x42, 0x3e, 0x41, 0x44, 0x36, 0x41, 0x3f, 0x3c, + 0x3e, 0x47, 0x45, 0x41, 0x38, 0x41, 0x3f, 0x43, 0x35, 0x32, 0x41, 0x39, + 0x36, 0x47, 0x35, 0x42, 0x44, 0x3b, 0x3f, 0x34, 0x48, 0x41, 0x43, 0x42, + 0x36, 0x3e, 0x3c, 0x3d, 0x3d, 0x3b, 0x42, 0x44, 0x3a, 0x44, 0x36, 0x2a, + 0x41, 0x39, 0x3a, 0x41, 0x46, 0x3c, 0x44, 0x2f, 0x36, 0x39, 0x3b, 0x3f, + 0x38, 0x45, 0x3c, 0x3c, 0x3e, 0x41, 0x3c, 0x39, 0x3e, 0x40, 0x2f, 0x45, + 0x3b, 0x41, 0x40, 0x3c, 0x4e, 0x38, 0x3e, 0x48, 0x46, 0x40, 0x48, 0x44, + 0x40, 0x4a, 0x45, 0x3c, 0x4f, 0x39, 0x37, 0x3a, 0x4e, 0x59, 0x5c, 0x22, + 0x58, 0x32, 0x38, 0x34, 0x40, 0x4b, 0x43, 0x43, 0x4f, 0x3e, 0x39, 0x40, + 0x37, 0x3e, 0x2f, 0x55, 0x3f, 0x40, 0x38, 0x3f, 0x3a, 0x33, 0x37, 0x3d, + 0x34, 0x4c, 0x37, 0x3f, 0x32, 0x39, 0x45, 0x34, 0x44, 0x4c, 0x3f, 0x3b, + 0x3c, 0x36, 0x36, 0x43, 0x36, 0x47, 0x41, 0x46, 0x41, 0x3e, 0x41, 0x3a, + 0x43, 0x3a, 0x48, 0x42, 0x42, 0x3e, 0x4c, 0x36, 0x3d, 0x39, 0x43, 0x46, + 0x3d, 0x42, 0x42, 0x3b, 0x45, 0x43, 0x3c, 0x40, 0x39, 0x37, 0x34, 0x45, + 0x3f, 0x40, 0x34, 0x38, 0x43, 0x3f, 0x36, 0x47, 0x3f, 0x3b, 0x49, 0x3c, + 0x3a, 0x3a, 0x42, 0x4c, 0x37, 0x3e, 0x3b, 0x32, 0x47, 0x40, 0x45, 0x4d, + 0x39, 0x3b, 0x39, 0x40, 0x3e, 0x3c, 0x3d, 0x3a, 0x3d, 0x3b, 0x3e, 0x43, + 0x3e, 0x3f, 0x3a, 0x3c, 0x41, 0x40, 0x39, 0x3c, 0x3a, 0x38, 0x39, 0x37, + 0x36, 0x33, 0x43, 0x45, 0x3f, 0x45, 0x41, 0x30, 0x3b, 0x34, 0x3c, 0x39, + 0x3b, 0x45, 0x37, 0x2e, 0x36, 0x34, 0x36, 0x44, 0x3d, 0x40, 0x3a, 0x3c, + 0x3d, 0x3b, 0x38, 0x41, 0x42, 0x3a, 0x32, 0x4b, 0x38, 0x3e, 0x41, 0x46, + 0x57, 0x3a, 0x44, 0x48, 0x47, 0x45, 0x47, 0x3e, 0x43, 0x42, 0x45, 0x3b, + 0x50, 0x39, 0x37, 0x3f, 0x47, 0x51, 0x5e, 0x22, 0x59, 0x33, 0x3c, 0x37, + 0x43, 0x50, 0x49, 0x47, 0x46, 0x42, 0x39, 0x44, 0x44, 0x3d, 0x2f, 0x53, + 0x35, 0x41, 0x40, 0x3d, 0x2d, 0x35, 0x2f, 0x3e, 0x3f, 0x37, 0x38, 0x3e, + 0x30, 0x45, 0x46, 0x38, 0x33, 0x3c, 0x3e, 0x3b, 0x44, 0x42, 0x47, 0x49, + 0x43, 0x40, 0x3d, 0x3c, 0x38, 0x43, 0x3e, 0x38, 0x3d, 0x40, 0x36, 0x43, + 0x43, 
0x3e, 0x40, 0x3c, 0x44, 0x47, 0x43, 0x3d, 0x41, 0x39, 0x3e, 0x45, + 0x39, 0x3d, 0x39, 0x40, 0x42, 0x40, 0x3b, 0x4a, 0x40, 0x41, 0x3f, 0x37, + 0x43, 0x41, 0x37, 0x4c, 0x3f, 0x3d, 0x38, 0x3a, 0x42, 0x46, 0x43, 0x4d, + 0x3c, 0x3a, 0x43, 0x3e, 0x3b, 0x3d, 0x46, 0x4a, 0x38, 0x3d, 0x3d, 0x39, + 0x3e, 0x3c, 0x3b, 0x3e, 0x3a, 0x40, 0x40, 0x34, 0x41, 0x3f, 0x3e, 0x3f, + 0x47, 0x3c, 0x32, 0x3a, 0x3c, 0x44, 0x3f, 0x42, 0x41, 0x43, 0x3e, 0x3a, + 0x3b, 0x42, 0x41, 0x39, 0x39, 0x37, 0x39, 0x3e, 0x3d, 0x33, 0x3e, 0x35, + 0x44, 0x37, 0x40, 0x35, 0x3f, 0x47, 0x37, 0x41, 0x35, 0x38, 0x47, 0x40, + 0x43, 0x44, 0x2e, 0x48, 0x35, 0x44, 0x41, 0x3c, 0x47, 0x3d, 0x3d, 0x52, + 0x48, 0x41, 0x44, 0x41, 0x42, 0x4b, 0x3e, 0x3d, 0x4e, 0x32, 0x34, 0x47, + 0x55, 0x57, 0x5f, 0x22, 0x57, 0x33, 0x40, 0x37, 0x40, 0x4a, 0x4d, 0x47, + 0x48, 0x38, 0x3e, 0x46, 0x37, 0x42, 0x28, 0x57, 0x38, 0x42, 0x36, 0x43, + 0x35, 0x37, 0x39, 0x39, 0x42, 0x39, 0x38, 0x3c, 0x35, 0x3c, 0x3c, 0x3a, + 0x3c, 0x4c, 0x45, 0x3f, 0x43, 0x3d, 0x45, 0x45, 0x40, 0x47, 0x3e, 0x3e, + 0x3d, 0x4b, 0x49, 0x35, 0x43, 0x3c, 0x36, 0x46, 0x3c, 0x46, 0x42, 0x44, + 0x3c, 0x42, 0x3d, 0x42, 0x44, 0x3c, 0x4a, 0x40, 0x40, 0x3c, 0x3b, 0x3c, + 0x35, 0x34, 0x2e, 0x46, 0x38, 0x3d, 0x38, 0x44, 0x41, 0x40, 0x3c, 0x52, + 0x3b, 0x3d, 0x3b, 0x3f, 0x42, 0x47, 0x44, 0x52, 0x44, 0x44, 0x39, 0x3f, + 0x43, 0x35, 0x3c, 0x4d, 0x39, 0x3d, 0x3b, 0x37, 0x3e, 0x38, 0x3e, 0x49, + 0x3a, 0x37, 0x3c, 0x49, 0x40, 0x41, 0x3c, 0x40, 0x3d, 0x38, 0x39, 0x3f, + 0x44, 0x3e, 0x42, 0x3e, 0x47, 0x40, 0x34, 0x46, 0x48, 0x37, 0x45, 0x3e, + 0x46, 0x3f, 0x35, 0x39, 0x38, 0x3f, 0x36, 0x2c, 0x40, 0x38, 0x3e, 0x3c, + 0x32, 0x3c, 0x46, 0x3a, 0x3f, 0x41, 0x36, 0x49, 0x42, 0x38, 0x36, 0x43, + 0x3d, 0x41, 0x46, 0x35, 0x4f, 0x3a, 0x41, 0x5c, 0x4a, 0x42, 0x4e, 0x42, + 0x46, 0x54, 0x3f, 0x45, 0x4c, 0x30, 0x33, 0x44, 0x56, 0x5d, 0x68, 0x26, + 0x60, 0x33, 0x3e, 0x3a, 0x42, 0x49, 0x52, 0x47, 0x51, 0x46, 0x40, 0x47, + 0x41, 0x3b, 0x1b, 0x4f, 0x3c, 0x45, 0x3d, 0x3d, 0x32, 0x2f, 0x3e, 0x3c, + 0x3c, 
0x3f, 0x3b, 0x3c, 0x2c, 0x3a, 0x41, 0x3c, 0x35, 0x3e, 0x3e, 0x3c, + 0x3d, 0x3f, 0x3e, 0x40, 0x40, 0x44, 0x42, 0x3c, 0x3c, 0x3c, 0x41, 0x3c, + 0x3c, 0x3d, 0x3e, 0x3d, 0x3c, 0x3d, 0x4a, 0x46, 0x3f, 0x35, 0x33, 0x43, + 0x42, 0x41, 0x4d, 0x48, 0x48, 0x44, 0x3e, 0x41, 0x41, 0x36, 0x3c, 0x4c, + 0x34, 0x47, 0x42, 0x39, 0x3e, 0x43, 0x3a, 0x53, 0x3b, 0x3b, 0x42, 0x3d, + 0x41, 0x3c, 0x3e, 0x52, 0x3a, 0x44, 0x34, 0x43, 0x3d, 0x3d, 0x3a, 0x50, + 0x3e, 0x33, 0x41, 0x40, 0x3f, 0x38, 0x43, 0x42, 0x3b, 0x37, 0x3e, 0x43, + 0x3f, 0x3c, 0x41, 0x49, 0x40, 0x32, 0x40, 0x3e, 0x3b, 0x3e, 0x44, 0x3c, + 0x35, 0x37, 0x3d, 0x41, 0x34, 0x3f, 0x3a, 0x3c, 0x47, 0x32, 0x41, 0x3d, + 0x3c, 0x3a, 0x4a, 0x31, 0x43, 0x38, 0x45, 0x37, 0x49, 0x3c, 0x34, 0x3f, + 0x3d, 0x3d, 0x3d, 0x45, 0x47, 0x3e, 0x37, 0x48, 0x40, 0x3b, 0x45, 0x3d, + 0x4e, 0x42, 0x3f, 0x57, 0x4b, 0x43, 0x4b, 0x3d, 0x3f, 0x47, 0x4a, 0x43, + 0x4e, 0x30, 0x38, 0x45, 0x59, 0x60, 0x64, 0x2d, 0x5a, 0x2d, 0x34, 0x35, + 0x47, 0x54, 0x4e, 0x3f, 0x44, 0x45, 0x3c, 0x43, 0x3d, 0x40, 0x1c, 0x5a, + 0x36, 0x3f, 0x3a, 0x39, 0x37, 0x3c, 0x32, 0x3b, 0x2d, 0x4a, 0x42, 0x35, + 0x30, 0x41, 0x43, 0x3d, 0x3d, 0x45, 0x38, 0x36, 0x3e, 0x40, 0x3a, 0x4a, + 0x34, 0x3d, 0x44, 0x3c, 0x39, 0x3b, 0x52, 0x38, 0x40, 0x3b, 0x3f, 0x3f, + 0x35, 0x37, 0x46, 0x48, 0x38, 0x3b, 0x40, 0x36, 0x3d, 0x3a, 0x4f, 0x45, + 0x35, 0x3a, 0x35, 0x33, 0x37, 0x43, 0x42, 0x52, 0x37, 0x3b, 0x3d, 0x42, + 0x44, 0x3d, 0x48, 0x58, 0x33, 0x3f, 0x41, 0x44, 0x44, 0x3f, 0x3b, 0x52, + 0x47, 0x39, 0x32, 0x3b, 0x38, 0x35, 0x48, 0x50, 0x34, 0x30, 0x39, 0x43, + 0x42, 0x40, 0x3b, 0x4b, 0x43, 0x3d, 0x34, 0x44, 0x33, 0x39, 0x44, 0x4b, + 0x45, 0x3e, 0x3c, 0x3f, 0x3a, 0x3e, 0x3c, 0x45, 0x36, 0x3e, 0x3d, 0x40, + 0x43, 0x46, 0x37, 0x3d, 0x3b, 0x42, 0x43, 0x3f, 0x3a, 0x41, 0x48, 0x2f, + 0x3e, 0x39, 0x3a, 0x39, 0x3f, 0x3a, 0x41, 0x40, 0x40, 0x3c, 0x3b, 0x3b, + 0x3f, 0x40, 0x3e, 0x42, 0x38, 0x3f, 0x38, 0x3c, 0x49, 0x45, 0x3f, 0x62, + 0x55, 0x47, 0x4c, 0x3c, 0x3c, 0x4a, 0x4c, 0x46, 0x4f, 0x39, 0x3a, 0x3b, + 0x5e, 
0x58, 0x6f, 0x2b, 0x5a, 0x2f, 0x3a, 0x35, 0x4b, 0x47, 0x4a, 0x46, + 0x45, 0x3e, 0x38, 0x4f, 0x3b, 0x3d, 0x21, 0x4b, 0x3d, 0x40, 0x37, 0x40, + 0x2d, 0x2c, 0x43, 0x3f, 0x2b, 0x3e, 0x3d, 0x39, 0x2f, 0x39, 0x44, 0x3c, + 0x39, 0x39, 0x43, 0x3b, 0x3d, 0x3b, 0x44, 0x39, 0x42, 0x42, 0x3e, 0x40, + 0x3b, 0x42, 0x53, 0x40, 0x32, 0x3d, 0x35, 0x3f, 0x3d, 0x45, 0x48, 0x46, + 0x3d, 0x43, 0x3c, 0x36, 0x35, 0x39, 0x3d, 0x4a, 0x39, 0x39, 0x3e, 0x41, + 0x38, 0x36, 0x3b, 0x53, 0x3c, 0x36, 0x32, 0x3b, 0x43, 0x3d, 0x42, 0x57, + 0x35, 0x2f, 0x38, 0x40, 0x2f, 0x3d, 0x3c, 0x4c, 0x40, 0x2f, 0x3a, 0x36, + 0x39, 0x3c, 0x3a, 0x51, 0x3d, 0x37, 0x39, 0x3c, 0x42, 0x40, 0x43, 0x52, + 0x3e, 0x42, 0x3e, 0x45, 0x36, 0x34, 0x42, 0x4b, 0x3a, 0x38, 0x37, 0x3f, + 0x36, 0x41, 0x3a, 0x45, 0x3e, 0x38, 0x35, 0x41, 0x35, 0x34, 0x37, 0x3c, + 0x3f, 0x31, 0x3c, 0x35, 0x33, 0x43, 0x36, 0x28, 0x44, 0x42, 0x3e, 0x42, + 0x3a, 0x41, 0x43, 0x35, 0x3d, 0x3f, 0x40, 0x3e, 0x3d, 0x33, 0x31, 0x41, + 0x3d, 0x40, 0x3b, 0x40, 0x51, 0x40, 0x3f, 0xfb, 0x51, 0x49, 0x4c, 0x3d, + 0x44, 0x4e, 0x47, 0x42, 0x50, 0x39, 0x39, 0x40, 0x59, 0x5d, 0x70, 0x2c, + 0x59, 0x39, 0x38, 0x2f, 0x46, 0x50, 0x51, 0x47, 0x4c, 0x3c, 0x39, 0x48, + 0x44, 0x3a, 0x1a, 0x51, 0x35, 0x3e, 0x34, 0x3a, 0x3d, 0x2b, 0x41, 0x39, + 0x37, 0x4d, 0x3e, 0x43, 0x38, 0x3b, 0x3a, 0x35, 0x36, 0x3a, 0x43, 0x39, + 0x39, 0x3a, 0x46, 0x3b, 0x39, 0x3c, 0x46, 0x36, 0x3e, 0x3d, 0x4b, 0x3d, + 0x3b, 0x46, 0x3a, 0x41, 0x31, 0x3c, 0x44, 0x4a, 0x37, 0x42, 0x39, 0x43, + 0x43, 0x3e, 0x40, 0x47, 0x3c, 0x3e, 0x3b, 0x43, 0x34, 0x3a, 0x43, 0x53, + 0x3f, 0x37, 0x39, 0x37, 0x3e, 0x3b, 0x46, 0x59, 0x37, 0x37, 0x33, 0x3d, + 0x38, 0x42, 0x36, 0x58, 0x2e, 0x32, 0x2b, 0x45, 0x32, 0x33, 0x36, 0x50, + 0x41, 0x3f, 0x37, 0x3d, 0x3f, 0x3d, 0x46, 0x49, 0x41, 0x38, 0x33, 0x3d, + 0x33, 0x32, 0x3a, 0x49, 0x41, 0x41, 0x3d, 0x33, 0x3b, 0x3b, 0x3a, 0x46, + 0x34, 0x44, 0x3f, 0x3b, 0x2f, 0x3f, 0x32, 0x3c, 0x3f, 0x43, 0x3e, 0x45, + 0x3a, 0x3c, 0x43, 0x26, 0x46, 0x37, 0x38, 0x3e, 0x36, 0x31, 0x3e, 0x34, + 0x39, 
0x3a, 0x38, 0x42, 0x38, 0x3e, 0x32, 0x42, 0x37, 0x37, 0x3c, 0x3a, + 0x48, 0x44, 0x3a, 0x68, 0x56, 0x46, 0x4d, 0x47, 0x40, 0x4e, 0x42, 0x46, + 0x51, 0x40, 0x38, 0x43, 0x58, 0x5d, 0x6a, 0x31, 0x57, 0x32, 0x3c, 0x36, + 0x49, 0x56, 0x52, 0x48, 0x4b, 0x41, 0x2f, 0x4d, 0x31, 0x43, 0x1b, 0x4c, + 0x30, 0x44, 0x33, 0x36, 0x2c, 0x3d, 0x45, 0x3a, 0x35, 0x46, 0x3d, 0x39, + 0x2e, 0x38, 0x3f, 0x37, 0x41, 0x44, 0x46, 0x31, 0x33, 0x46, 0x37, 0x37, + 0x3f, 0x41, 0x45, 0x30, 0x46, 0x3b, 0x50, 0x3b, 0x40, 0x39, 0x42, 0x43, + 0x35, 0x37, 0x40, 0x44, 0x3b, 0x41, 0x3d, 0x37, 0x3a, 0x41, 0x3d, 0x46, + 0x36, 0x41, 0x38, 0x41, 0x38, 0x3d, 0x45, 0x58, 0x3d, 0x3a, 0x3d, 0x44, + 0x45, 0x38, 0x48, 0x5c, 0x3d, 0x39, 0x43, 0x45, 0x41, 0x3e, 0x4a, 0x56, + 0x40, 0x33, 0x30, 0x31, 0x42, 0x39, 0x38, 0x56, 0x30, 0x3a, 0x35, 0x3e, + 0x3f, 0x38, 0x36, 0x47, 0x3c, 0x3a, 0x3d, 0x3f, 0x37, 0x35, 0x3b, 0x4d, + 0x43, 0x36, 0x39, 0x37, 0x3e, 0x42, 0x3d, 0x3f, 0x40, 0x3f, 0x34, 0x3b, + 0x3f, 0x3e, 0x3b, 0x39, 0x3b, 0x3a, 0x3a, 0x3c, 0x34, 0x3f, 0x3c, 0x2a, + 0x49, 0x3b, 0x36, 0x3c, 0x35, 0x46, 0x38, 0x3b, 0x3c, 0x39, 0x38, 0x42, + 0x39, 0x36, 0x2e, 0x4a, 0x3d, 0x39, 0x3f, 0x3f, 0x4b, 0x45, 0x3e, 0x67, + 0x4b, 0x4b, 0x49, 0x3e, 0x3f, 0x53, 0x4c, 0x55, 0x47, 0x32, 0x3b, 0x39, + 0x54, 0x5b, 0x6f, 0x29, 0x5a, 0x34, 0x3e, 0x26, 0x45, 0x52, 0x59, 0x44, + 0x59, 0x39, 0x3c, 0x47, 0x36, 0x46, 0x16, 0x50, 0x32, 0x46, 0x34, 0x35, + 0x35, 0x2d, 0x39, 0x38, 0x2c, 0x42, 0x43, 0x3b, 0x32, 0x3f, 0x37, 0x2f, + 0x34, 0x43, 0x46, 0x3b, 0x3b, 0x41, 0x3c, 0x37, 0x3e, 0x43, 0x4b, 0x36, + 0x3e, 0x3c, 0x4c, 0x42, 0x40, 0x3f, 0x49, 0x40, 0x3c, 0x40, 0x3c, 0x48, + 0x35, 0x42, 0x3f, 0x42, 0x44, 0x40, 0x45, 0x4f, 0x3f, 0x3f, 0x40, 0x42, + 0x3b, 0x3d, 0x49, 0x55, 0x42, 0x39, 0x41, 0x3b, 0x3f, 0x38, 0x44, 0x60, + 0x34, 0x40, 0x3b, 0x3b, 0x35, 0x3d, 0x41, 0x4e, 0x35, 0x33, 0x30, 0x3a, + 0x3a, 0x32, 0x42, 0x4f, 0x33, 0x34, 0x2f, 0x38, 0x49, 0x38, 0x40, 0x4c, + 0x35, 0x38, 0x3e, 0x46, 0x3f, 0x3a, 0x3a, 0x45, 0x3b, 0x34, 0x2e, 0x39, + 0x32, 
0x3e, 0x40, 0x48, 0x35, 0x44, 0x3a, 0x34, 0x3f, 0x35, 0x3b, 0x32, + 0x40, 0x43, 0x3e, 0x38, 0x3b, 0x43, 0x3c, 0x2b, 0x46, 0x43, 0x40, 0x32, + 0x42, 0x3b, 0x49, 0x2e, 0x3b, 0x3a, 0x3e, 0x41, 0x3c, 0x3f, 0x31, 0x3b, + 0x41, 0x33, 0x41, 0x3c, 0x4d, 0x40, 0x38, 0x68, 0x4c, 0x4c, 0x4e, 0x3f, + 0x3f, 0x54, 0x4a, 0x3d, 0x4c, 0x33, 0x3b, 0x3a, 0x5d, 0x60, 0x71, 0x2b, + 0x59, 0x33, 0x3c, 0x2c, 0x47, 0x52, 0x4f, 0x51, 0x56, 0x3d, 0x39, 0x44, + 0x35, 0x41, 0x1b, 0x4a, 0x35, 0x41, 0x37, 0x35, 0x2c, 0x35, 0x37, 0x35, + 0x38, 0x41, 0x38, 0x3e, 0x3c, 0x40, 0x3c, 0x2f, 0x38, 0x3e, 0x3f, 0x45, + 0x40, 0x3d, 0x3c, 0x35, 0x3c, 0x46, 0x43, 0x39, 0x37, 0x42, 0x4e, 0x3c, + 0x42, 0x46, 0x37, 0x33, 0x43, 0x3f, 0x47, 0x4a, 0x3d, 0x3e, 0x40, 0x40, + 0x40, 0x3f, 0x4b, 0x54, 0x36, 0x3f, 0x37, 0x40, 0x39, 0x39, 0x47, 0x51, + 0x3d, 0x39, 0x36, 0x36, 0x40, 0x40, 0x41, 0x5a, 0x38, 0x39, 0x42, 0x38, + 0x40, 0x39, 0x43, 0x50, 0x3a, 0x3a, 0x32, 0x3c, 0x3c, 0x35, 0x44, 0x4a, + 0x37, 0x35, 0x36, 0x3c, 0x35, 0x30, 0x48, 0x4b, 0x3c, 0x33, 0x37, 0x3e, + 0x42, 0x3c, 0x42, 0x4e, 0x41, 0x32, 0x3e, 0x33, 0x49, 0x39, 0x3e, 0x42, + 0x3d, 0x39, 0x37, 0x36, 0x35, 0x41, 0x3e, 0x37, 0x37, 0x3e, 0x3d, 0x38, + 0x3a, 0x3c, 0x41, 0x29, 0x3c, 0x3b, 0x39, 0x40, 0x43, 0x3d, 0x3e, 0x33, + 0x3f, 0x3f, 0x3e, 0x43, 0x43, 0x38, 0x38, 0x41, 0x3b, 0x38, 0x35, 0x3a, + 0x4b, 0x44, 0x44, 0x55, 0x4e, 0x44, 0x4d, 0x49, 0x3e, 0x53, 0x45, 0x3f, + 0x45, 0x3d, 0x36, 0x36, 0x4f, 0x5b, 0x6b, 0x28, 0x59, 0x34, 0x39, 0x34, + 0x4f, 0x4d, 0x52, 0x3e, 0x51, 0x34, 0x35, 0x4a, 0x3b, 0x3f, 0x21, 0x45, + 0x36, 0x3f, 0x38, 0x33, 0x2c, 0x37, 0x32, 0x2f, 0x2b, 0x44, 0x47, 0x3f, + 0x38, 0x3a, 0x3f, 0x2e, 0x41, 0x3f, 0x3d, 0x41, 0x35, 0x48, 0x43, 0x40, + 0x33, 0x44, 0x40, 0x38, 0x47, 0x44, 0x4c, 0x3d, 0x41, 0x3b, 0x39, 0x36, + 0x3e, 0x44, 0x49, 0x48, 0x3c, 0x3b, 0x34, 0x34, 0x3f, 0x3c, 0x42, 0x52, + 0x43, 0x41, 0x3c, 0x3c, 0x3d, 0x43, 0x48, 0x54, 0x39, 0x35, 0x39, 0x3c, + 0x43, 0x3c, 0x44, 0x5f, 0x39, 0x3d, 0x38, 0x3f, 0x36, 0x3d, 0x43, 0x58, + 0x33, 
0x3d, 0x43, 0x33, 0x3f, 0x36, 0x39, 0x54, 0x3a, 0x37, 0x2d, 0x46, + 0x43, 0x41, 0x47, 0x46, 0x3e, 0x42, 0x34, 0x49, 0x3a, 0x3f, 0x38, 0x50, + 0x3a, 0x3b, 0x42, 0x3a, 0x3e, 0x3c, 0x3b, 0x40, 0x42, 0x45, 0x37, 0x3b, + 0x2f, 0x3b, 0x46, 0x30, 0x42, 0x3b, 0x3b, 0x44, 0x3b, 0x3e, 0x40, 0x1e, + 0x33, 0x40, 0x40, 0x3d, 0x39, 0x3a, 0x41, 0x33, 0x45, 0x3e, 0x3c, 0x3f, + 0x3f, 0x38, 0x31, 0x46, 0x3b, 0x35, 0x42, 0x39, 0x49, 0x3e, 0x3d, 0x66, + 0x53, 0x3f, 0x44, 0x40, 0x43, 0x45, 0x48, 0x45, 0x49, 0x2d, 0x3e, 0x3a, + 0x4f, 0x5a, 0x62, 0x27, 0x54, 0x37, 0x35, 0x34, 0x42, 0x50, 0x54, 0x43, + 0x4d, 0x38, 0x39, 0x48, 0x38, 0x4c, 0x21, 0x3f, 0x40, 0x3a, 0x3a, 0x2f, + 0x37, 0x2f, 0x29, 0x2c, 0x36, 0x47, 0x3f, 0x41, 0x31, 0x33, 0x3e, 0x32, + 0x3e, 0x40, 0x42, 0x40, 0x42, 0x3a, 0x46, 0x33, 0x44, 0x40, 0x3c, 0x43, + 0x3d, 0x41, 0x4d, 0x3d, 0x3c, 0x47, 0x46, 0x43, 0x42, 0x3e, 0x44, 0x4e, + 0x41, 0x3a, 0x44, 0x38, 0x45, 0x3b, 0x49, 0x4c, 0x40, 0x3f, 0x37, 0x3e, + 0x3e, 0x46, 0x41, 0x51, 0x3f, 0x39, 0x30, 0x40, 0x3e, 0x38, 0x43, 0x5b, + 0x33, 0x3e, 0x31, 0x42, 0x3d, 0x2f, 0x49, 0x57, 0x37, 0x31, 0x46, 0x44, + 0x3e, 0x35, 0x40, 0x55, 0x36, 0x35, 0x3d, 0x3c, 0x38, 0x33, 0x42, 0x52, + 0x3b, 0x39, 0x34, 0x31, 0x45, 0x34, 0x3c, 0x51, 0x33, 0x39, 0x3c, 0x40, + 0x36, 0x36, 0x42, 0x3e, 0x37, 0x3e, 0x3b, 0x40, 0x3d, 0x36, 0x41, 0x30, + 0x42, 0x45, 0x40, 0x49, 0x3d, 0x32, 0x46, 0x26, 0x40, 0x44, 0x3a, 0x3f, + 0x3d, 0x46, 0x45, 0x31, 0x33, 0x34, 0x3e, 0x37, 0x46, 0x3b, 0x32, 0x3a, + 0x3d, 0x31, 0x3c, 0x36, 0x50, 0x41, 0x3b, 0x5d, 0x53, 0x42, 0x44, 0x37, + 0x3e, 0x4d, 0x41, 0x4b, 0x49, 0x2f, 0x35, 0x3a, 0x4e, 0x59, 0x5d, 0x27, + 0x5c, 0x30, 0x3d, 0x3a, 0x46, 0x50, 0x57, 0x4a, 0x4c, 0x36, 0x37, 0x46, + 0x48, 0x41, 0x24, 0x49, 0x36, 0x3e, 0x41, 0x45, 0x37, 0x38, 0x2e, 0x2e, + 0x34, 0x3c, 0x38, 0x41, 0x36, 0x3d, 0x43, 0x36, 0x3e, 0x3e, 0x41, 0x3b, + 0x42, 0x3c, 0x43, 0x38, 0x3e, 0x3d, 0x41, 0x48, 0x47, 0x4c, 0x45, 0x3b, + 0x37, 0x41, 0x38, 0x41, 0x3d, 0x41, 0x46, 0x4e, 0x36, 0x45, 0x38, 0x39, + 0x42, 
0x42, 0x37, 0x4c, 0x34, 0x46, 0x3c, 0x44, 0x4a, 0x39, 0x45, 0x53, + 0x3c, 0x3f, 0x41, 0x35, 0x3c, 0x45, 0x4c, 0x5a, 0x44, 0x41, 0x30, 0x35, + 0x40, 0x39, 0x42, 0x5a, 0x36, 0x36, 0x3a, 0x3b, 0x43, 0x35, 0x3c, 0x56, + 0x35, 0x38, 0x2b, 0x4a, 0x3c, 0x40, 0x45, 0x54, 0x37, 0x37, 0x3a, 0x44, + 0x42, 0x3b, 0x3d, 0x4a, 0x3f, 0x37, 0x3b, 0x35, 0x34, 0x3f, 0x40, 0x48, + 0x45, 0x3e, 0x37, 0x38, 0x41, 0x41, 0x3d, 0x37, 0x43, 0x3d, 0x3d, 0x45, + 0x3a, 0x38, 0x3f, 0x23, 0x4a, 0x37, 0x42, 0x3c, 0x3f, 0x43, 0x42, 0x33, + 0x37, 0x39, 0x35, 0x3b, 0x41, 0x36, 0x2f, 0x3b, 0x41, 0x3a, 0x44, 0x3d, + 0x3e, 0x45, 0x44, 0x50, 0x47, 0x47, 0x48, 0x3c, 0x3f, 0x45, 0x43, 0x3f, + 0x4a, 0x33, 0x3c, 0x3a, 0x52, 0x52, 0x5a, 0x23, 0x58, 0x31, 0x3b, 0x3b, + 0x47, 0x44, 0x54, 0x45, 0x42, 0x38, 0x38, 0x40, 0x43, 0x3f, 0x2a, 0x46, + 0x3b, 0x46, 0x3b, 0x46, 0x35, 0x37, 0x29, 0x35, 0x38, 0x41, 0x3a, 0x31, + 0x44, 0x41, 0x39, 0x36, 0x45, 0x41, 0x40, 0x3e, 0x40, 0x44, 0x47, 0x37, + 0x3f, 0x42, 0x49, 0x34, 0x46, 0x3d, 0x4b, 0x3d, 0x42, 0x3b, 0x42, 0x3e, + 0x41, 0x3b, 0x3f, 0x43, 0x47, 0x45, 0x47, 0x41, 0x40, 0x3a, 0x3d, 0x45, + 0x40, 0x36, 0x3b, 0x3b, 0x44, 0x37, 0x46, 0x55, 0x35, 0x42, 0x3f, 0x3a, + 0x41, 0x41, 0x44, 0x5c, 0x31, 0x44, 0x3d, 0x46, 0x39, 0x38, 0x46, 0x59, + 0x41, 0x3b, 0x3d, 0x39, 0x33, 0x3e, 0x41, 0x58, 0x33, 0x44, 0x34, 0x31, + 0x48, 0x3e, 0x4d, 0x56, 0x36, 0x3c, 0x37, 0x46, 0x46, 0x38, 0x45, 0x53, + 0x35, 0x3d, 0x3a, 0x31, 0x42, 0x48, 0x45, 0x44, 0x3b, 0x3b, 0x3c, 0x41, + 0x3d, 0x42, 0x3f, 0x2f, 0x38, 0x3c, 0x3e, 0x41, 0x44, 0x3a, 0x4a, 0x24, + 0x37, 0x3e, 0x37, 0x48, 0x40, 0x3f, 0x46, 0x3c, 0x47, 0x4a, 0x38, 0x47, + 0x34, 0x45, 0x31, 0x42, 0x43, 0x44, 0x3f, 0x3f, 0x49, 0x40, 0x3c, 0x41, + 0x4d, 0x43, 0x42, 0x39, 0x39, 0x48, 0x41, 0x38, 0x47, 0x3c, 0x3c, 0x42, + 0x44, 0x55, 0x62, 0x2a, 0x5c, 0x32, 0x3a, 0x37, 0x4c, 0x44, 0x4f, 0x3e, + 0x4e, 0x42, 0x3a, 0x42, 0x41, 0x4a, 0x35, 0x44, 0x45, 0x3b, 0x43, 0x41, + 0x33, 0x38, 0x28, 0x36, 0x40, 0x47, 0x3e, 0x3e, 0x3e, 0x39, 0x3a, 0x37, + 0x44, 
0x44, 0x3f, 0x3b, 0x41, 0x3c, 0x45, 0x36, 0x38, 0x3a, 0x3c, 0x42, + 0x42, 0x3f, 0x59, 0x3c, 0x47, 0x3d, 0x38, 0x3a, 0x42, 0x44, 0x41, 0x46, + 0x3f, 0x43, 0x48, 0x42, 0x44, 0x35, 0x3f, 0x45, 0x36, 0x3f, 0x38, 0x3a, + 0x44, 0x3d, 0x3d, 0x4e, 0x3e, 0x45, 0x40, 0x42, 0x3c, 0x33, 0x43, 0x5a, + 0x38, 0x3e, 0x45, 0x3a, 0x3e, 0x42, 0x45, 0x52, 0x3c, 0x42, 0x3a, 0x38, + 0x3d, 0x3b, 0x4a, 0x57, 0x38, 0x37, 0x47, 0x44, 0x3e, 0x3c, 0x38, 0x48, + 0x36, 0x41, 0x3f, 0x41, 0x3a, 0x3a, 0x46, 0x47, 0x42, 0x40, 0x32, 0x33, + 0x43, 0x37, 0x41, 0x43, 0x3e, 0x40, 0x3d, 0x3a, 0x3e, 0x38, 0x42, 0x30, + 0x3e, 0x40, 0x46, 0x42, 0x40, 0x44, 0x42, 0x23, 0x31, 0x40, 0x3f, 0x3d, + 0x3b, 0x33, 0x40, 0x33, 0x41, 0x33, 0x43, 0x41, 0x3a, 0x3e, 0x36, 0x40, + 0x40, 0x45, 0x37, 0x42, 0x46, 0x42, 0x39, 0x48, 0x44, 0x40, 0x40, 0x45, + 0x3c, 0x49, 0x41, 0x3f, 0x4c, 0x3d, 0x2f, 0x3f, 0x47, 0x52, 0x54, 0x2c, + 0x55, 0x42, 0x44, 0x3b, 0x46, 0x4f, 0x48, 0x3c, 0x45, 0x39, 0x3f, 0x4b, + 0x3f, 0x3f, 0x36, 0x42, 0x41, 0x48, 0x44, 0x44, 0x36, 0x3b, 0x37, 0x40, + 0x39, 0x49, 0x3a, 0x35, 0x3e, 0x48, 0x31, 0x30, 0x44, 0x38, 0x4c, 0x3c, + 0x41, 0x3e, 0x46, 0x32, 0x44, 0x3b, 0x42, 0x3c, 0x38, 0x3a, 0x47, 0x3f, + 0x3a, 0x42, 0x3a, 0x43, 0x40, 0x4b, 0x47, 0x3c, 0x42, 0x46, 0x45, 0x42, + 0x3c, 0x46, 0x3d, 0x3f, 0x3e, 0x36, 0x38, 0x3e, 0x46, 0x3c, 0x4d, 0x43, + 0x49, 0x41, 0x48, 0x3c, 0x3d, 0x39, 0x43, 0x58, 0x3a, 0x41, 0x3f, 0x38, + 0x37, 0x3f, 0x46, 0x5d, 0x3c, 0x3c, 0x39, 0x36, 0x3d, 0x46, 0x43, 0x50, + 0x3a, 0x47, 0x39, 0x36, 0x41, 0x3f, 0x3e, 0x51, 0x31, 0x36, 0x3e, 0x3c, + 0x3c, 0x3a, 0x48, 0x41, 0x3a, 0x43, 0x49, 0x3e, 0x42, 0x46, 0x3f, 0x41, + 0x49, 0x33, 0x42, 0x41, 0x45, 0x40, 0x3d, 0x2b, 0x3d, 0x38, 0x40, 0x37, + 0x3a, 0x31, 0x45, 0x26, 0x33, 0x3d, 0x3f, 0x39, 0x36, 0x3c, 0x38, 0x33, + 0x34, 0x3f, 0x35, 0x44, 0x3a, 0x39, 0x32, 0x41, 0x35, 0x40, 0x3c, 0x3b, + 0x4a, 0x3f, 0x3e, 0x3e, 0x4a, 0x3e, 0x42, 0x35, 0x38, 0x43, 0x3c, 0x37, + 0x3d, 0x3c, 0x39, 0x43, 0x3f, 0x4e, 0x54, 0x33, 0x4b, 0x37, 0x43, 0x3b, + 0x43, 
0x48, 0x43, 0x42, 0x3d, 0x46, 0x45, 0x49, 0x3a, 0x39, 0x36, 0x4a, + 0x48, 0x48, 0x37, 0x4b, 0x42, 0x47, 0x34, 0x34, 0x43, 0x42, 0x3a, 0x3d, + 0x3c, 0x46, 0x34, 0x39, 0x40, 0x3b, 0x3e, 0x3e, 0x37, 0x3d, 0x53, 0x3b, + 0x48, 0x3c, 0x43, 0x44, 0x3b, 0x3f, 0x57, 0x3d, 0x39, 0x3c, 0x39, 0x3a, + 0x3e, 0x3f, 0x43, 0x3e, 0x41, 0x47, 0x3c, 0x41, 0x40, 0x41, 0x37, 0x3f, + 0x3b, 0x43, 0x35, 0x3e, 0x45, 0x40, 0x47, 0x59, 0x41, 0x49, 0x3b, 0x3f, + 0x47, 0x49, 0x4b, 0x61, 0x39, 0x48, 0x39, 0x3e, 0x44, 0x34, 0x3b, 0x59, + 0x3c, 0x42, 0x45, 0x35, 0x42, 0x41, 0x39, 0x52, 0x42, 0x3c, 0x3d, 0x3e, + 0x3d, 0x4a, 0x4a, 0x4d, 0x3c, 0x34, 0x44, 0x3c, 0x41, 0x34, 0x38, 0x46, + 0x38, 0x45, 0x40, 0x45, 0x40, 0x3a, 0x3d, 0x44, 0x3a, 0x37, 0x3a, 0x3a, + 0x3b, 0x42, 0x40, 0x34, 0x3b, 0x3c, 0x42, 0x40, 0x3d, 0x32, 0x40, 0x27, + 0x37, 0x39, 0x37, 0x46, 0x48, 0x31, 0x40, 0x30, 0x42, 0x42, 0x3a, 0x40, + 0x3d, 0x37, 0x2a, 0x40, 0x41, 0x37, 0x3c, 0x4a, 0x46, 0x45, 0x3d, 0x34, + 0x48, 0x41, 0x42, 0x3e, 0x3f, 0x39, 0x3c, 0x3a, 0x4f, 0x3b, 0x32, 0x3e, + 0x43, 0x51, 0x4f, 0x2a, 0x46, 0x3a, 0x3d, 0x3b, 0x40, 0x3d, 0x4c, 0x3c, + 0x48, 0x40, 0x36, 0x4a, 0x3a, 0x38, 0x42, 0x43, 0x4c, 0x3d, 0x47, 0x47, + 0x33, 0x3f, 0x2d, 0x37, 0x4a, 0x43, 0x38, 0x3e, 0x49, 0x42, 0x42, 0x3d, + 0x43, 0x47, 0x41, 0x38, 0x46, 0x37, 0x46, 0x38, 0x47, 0x42, 0x49, 0x3d, + 0x3b, 0x37, 0x4c, 0x3c, 0x3a, 0x45, 0x3f, 0x37, 0x36, 0x3d, 0x3c, 0x40, + 0x3e, 0x45, 0x46, 0x41, 0x41, 0x40, 0x3c, 0x44, 0x47, 0x43, 0x37, 0x3f, + 0x3e, 0x3a, 0x3a, 0x4b, 0x3a, 0x36, 0x3d, 0x3f, 0x38, 0x3f, 0x3c, 0x58, + 0x40, 0x49, 0x3d, 0x42, 0x38, 0x3a, 0x47, 0x50, 0x3b, 0x49, 0x40, 0x44, + 0x3e, 0x3c, 0x38, 0x52, 0x3a, 0x3e, 0x44, 0x3c, 0x35, 0x44, 0x3a, 0x47, + 0x3e, 0x49, 0x3f, 0x47, 0x45, 0x39, 0x3b, 0x46, 0x44, 0x3e, 0x41, 0x46, + 0x40, 0x41, 0x40, 0x40, 0x3a, 0x35, 0x3e, 0x36, 0x3e, 0x3e, 0x3d, 0x35, + 0x3b, 0x3c, 0x38, 0x46, 0x3b, 0x3c, 0x41, 0x2c, 0x3f, 0x42, 0x38, 0x3b, + 0x36, 0x3b, 0x39, 0x40, 0x40, 0x38, 0x36, 0x33, 0x34, 0x42, 0x2f, 0x44, + 0x41, 
0x40, 0x39, 0x35, 0x3b, 0x44, 0x42, 0x2c, 0x41, 0x3b, 0x44, 0x41, + 0x35, 0x44, 0x3b, 0x34, 0x44, 0x49, 0x36, 0x39, 0x3a, 0x52, 0x4d, 0x2b, + 0x40, 0x40, 0x3e, 0x39, 0x48, 0x42, 0x3c, 0x44, 0x46, 0x49, 0x3f, 0x54, + 0x43, 0x40, 0x2e, 0x40, 0x4f, 0x36, 0x3e, 0x3f, 0x38, 0x48, 0x44, 0x3c, + 0x44, 0x43, 0x41, 0x47, 0x40, 0x46, 0x40, 0x37, 0x41, 0x34, 0x3a, 0x41, + 0x41, 0x3b, 0x49, 0x39, 0x42, 0x38, 0x3d, 0x39, 0x34, 0x35, 0x43, 0x36, + 0x3e, 0x44, 0x3f, 0x40, 0x43, 0x40, 0x40, 0x3a, 0x47, 0x42, 0x3e, 0x42, + 0x46, 0x35, 0x3a, 0x46, 0x3c, 0x3c, 0x3c, 0x3d, 0x3f, 0x40, 0x43, 0x4c, + 0x3a, 0x37, 0x3f, 0x43, 0x47, 0x38, 0x42, 0x58, 0x42, 0x3b, 0x34, 0x37, + 0x3e, 0x48, 0x3c, 0x57, 0x44, 0x3c, 0x3d, 0x3a, 0x36, 0x48, 0x3c, 0x51, + 0x3d, 0x48, 0x45, 0x45, 0x38, 0x45, 0x40, 0x3f, 0x3b, 0x35, 0x3d, 0x3f, + 0x38, 0x47, 0x39, 0x3b, 0x36, 0x49, 0x43, 0x40, 0x3f, 0x46, 0x38, 0x40, + 0x3f, 0x3e, 0x39, 0x32, 0x47, 0x42, 0x35, 0x33, 0x39, 0x47, 0x3c, 0x36, + 0x3b, 0x37, 0x43, 0x35, 0x3b, 0x3b, 0x34, 0x3b, 0x38, 0x3d, 0x3e, 0x3a, + 0x35, 0x49, 0x38, 0x40, 0x3f, 0x3f, 0x3e, 0x37, 0x43, 0x3b, 0x3e, 0x3e, + 0x3b, 0x40, 0x44, 0x39, 0x3d, 0x3f, 0x31, 0x42, 0x42, 0x3b, 0x41, 0x3d, + 0x3e, 0x3c, 0x37, 0x34, 0x48, 0x3d, 0x49, 0x4a, 0x47, 0x36, 0x3a, 0x34, + 0x37, 0x36, 0x3e, 0x38, 0x33, 0x45, 0x39, 0x44, 0x34, 0x49, 0x3a, 0x3d, + 0x34, 0x31, 0x31, 0x3d, 0x34, 0x3d, 0x41, 0x3e, 0x49, 0x41, 0x34, 0x3f, + 0x3a, 0x42, 0x3e, 0x40, 0x3f, 0x33, 0x46, 0x3f, 0x34, 0x39, 0x37, 0x46, + 0x3e, 0x32, 0x3f, 0x45, 0x45, 0x41, 0x3b, 0x4b, 0x35, 0x35, 0x3b, 0x4a, + 0x3d, 0x43, 0x3b, 0x44, 0x3c, 0x38, 0x31, 0x43, 0x39, 0x35, 0x41, 0x45, + 0x37, 0x3e, 0x43, 0x47, 0x39, 0x40, 0x41, 0x41, 0x40, 0x32, 0x37, 0x3e, + 0x3d, 0x39, 0x3b, 0x49, 0x33, 0x35, 0x38, 0x41, 0x45, 0x37, 0x3c, 0x49, + 0x3b, 0x34, 0x34, 0x41, 0x3a, 0x3f, 0x3e, 0x47, 0x39, 0x3c, 0x34, 0x3a, + 0x38, 0x44, 0x40, 0x51, 0x3a, 0x37, 0x3b, 0x3f, 0x3d, 0x3a, 0x45, 0x48, + 0x3f, 0x46, 0x35, 0x43, 0x38, 0x43, 0x35, 0x4c, 0x42, 0x47, 0x44, 0x3d, + 0x40, 
0x3a, 0x39, 0x4e, 0x3d, 0x37, 0x3c, 0x42, 0x40, 0x48, 0x44, 0x4c, + 0x31, 0x40, 0x42, 0x3b, 0x45, 0x45, 0x3f, 0x3e, 0x3d, 0x44, 0x3f, 0x31, + 0x3f, 0x44, 0x45, 0x37, 0x3e, 0x3d, 0x35, 0x3b, 0x2d, 0x44, 0x4a, 0x3a, + 0x2b, 0x37, 0x38, 0x46, 0x41, 0x39, 0x3c, 0x3c, 0x46, 0x33, 0x36, 0x3c, + 0x4b, 0x34, 0x49, 0x50, 0x30, 0x3c, 0x33, 0x41, 0x44, 0x33, 0x43, 0x39, + 0x36, 0x45, 0x33, 0x3b, 0x3d, 0x36, 0x47, 0x30, 0x42, 0x37, 0x49, 0x3e, + 0x3b, 0x49, 0x3d, 0x3b, 0x3a, 0x41, 0x38, 0x44, 0x42, 0x3b, 0x3f, 0x40, + 0x46, 0x35, 0x38, 0x3c, 0x48, 0x3a, 0x46, 0x41, 0x36, 0x36, 0x41, 0x3e, + 0x43, 0x3e, 0x32, 0x39, 0x3a, 0x41, 0x30, 0x3e, 0x40, 0x3e, 0x36, 0x3a, + 0x45, 0x45, 0x3a, 0x3c, 0x31, 0x3b, 0x47, 0x3f, 0x36, 0x3a, 0x3c, 0x41, + 0x3b, 0x41, 0x39, 0x46, 0x3f, 0x3c, 0x34, 0x3e, 0x41, 0x45, 0x41, 0x42, + 0x39, 0x40, 0x40, 0x44, 0x45, 0x42, 0x34, 0x3f, 0x3e, 0x31, 0x3b, 0x41, + 0x33, 0x43, 0x37, 0x44, 0x44, 0x3a, 0x36, 0x36, 0x48, 0x3c, 0x37, 0x47, + 0x39, 0x3e, 0x3e, 0x3c, 0x3c, 0x41, 0x3c, 0x44, 0x3b, 0x42, 0x3f, 0x3a, + 0x43, 0x3b, 0x3e, 0x48, 0x36, 0x3f, 0x3d, 0x34, 0x40, 0x43, 0x35, 0x4f, + 0x34, 0x39, 0x3b, 0x41, 0x40, 0x39, 0x37, 0x4c, 0x39, 0x36, 0x39, 0x39, + 0x47, 0x41, 0x43, 0x3f, 0x3f, 0x33, 0x42, 0x3f, 0x42, 0x40, 0x37, 0x40, + 0x3f, 0x34, 0x45, 0x3d, 0x2d, 0x3c, 0x44, 0x3b, 0x43, 0x37, 0x26, 0x50, + 0x43, 0x44, 0x3d, 0x43, 0x42, 0x2d, 0x3c, 0x33, 0x4a, 0x32, 0x4a, 0x53, + 0x33, 0x38, 0x27, 0x36, 0x42, 0x30, 0x47, 0x3d, 0x36, 0x45, 0x46, 0x36, + 0x3b, 0x3b, 0x40, 0x33, 0x37, 0x36, 0x44, 0x46, 0x3d, 0x35, 0x40, 0x38, + 0x3b, 0x40, 0x36, 0x3c, 0x3d, 0x37, 0x31, 0x41, 0x33, 0x3c, 0x38, 0x3f, + 0x43, 0x3a, 0x40, 0x49, 0x38, 0x39, 0x38, 0x3d, 0x43, 0x3d, 0x39, 0x3b, + 0x3d, 0x3f, 0x38, 0x42, 0x34, 0x43, 0x33, 0x3e, 0x43, 0x3e, 0x40, 0x42, + 0x3b, 0x45, 0x37, 0x44, 0x43, 0x39, 0x3c, 0x3d, 0x37, 0x44, 0x3a, 0x3b, + 0x47, 0x3f, 0x3a, 0x3c, 0x3a, 0x3b, 0x3f, 0x43, 0x3e, 0x3d, 0x46, 0x3e, + 0x37, 0x36, 0x3f, 0x40, 0x42, 0x42, 0x37, 0x36, 0x48, 0x35, 0x44, 0x44, + 0x39, 
0x3c, 0x3b, 0x41, 0x44, 0x49, 0x3a, 0x40, 0x41, 0x36, 0x33, 0x3a, + 0x3c, 0x3d, 0x40, 0x3f, 0x43, 0x36, 0x3c, 0x3a, 0x3f, 0x4b, 0x32, 0x49, + 0x49, 0x3e, 0x3a, 0x3e, 0x3f, 0x41, 0x3c, 0x47, 0x40, 0x41, 0x45, 0x3e, + 0x47, 0x47, 0x3f, 0x47, 0x45, 0x3e, 0x31, 0x43, 0x4a, 0x44, 0x36, 0x40, + 0x41, 0x47, 0x3e, 0x42, 0x37, 0x40, 0x3b, 0x46, 0x37, 0x41, 0x3e, 0x3c, + 0x27, 0x40, 0x49, 0x42, 0x42, 0x39, 0x30, 0x49, 0x43, 0x38, 0x3d, 0x42, + 0x43, 0x2f, 0x3b, 0x37, 0x4b, 0x2d, 0x4f, 0x52, 0x30, 0x31, 0x2f, 0x3a, + 0x49, 0x38, 0x4f, 0x45, 0x2e, 0x47, 0x3a, 0x32, 0x33, 0x3f, 0x4a, 0x2e, + 0x33, 0x3b, 0x3e, 0x3e, 0x49, 0x45, 0x44, 0x38, 0x3c, 0x35, 0x45, 0x47, + 0x41, 0x3b, 0x3c, 0x48, 0x46, 0x39, 0x39, 0x3b, 0x3f, 0x41, 0x38, 0x42, + 0x3d, 0x46, 0x33, 0x41, 0x36, 0x3f, 0x3f, 0x3c, 0x33, 0x3e, 0x3e, 0x40, + 0x44, 0x40, 0x3c, 0x38, 0x46, 0x3a, 0x40, 0x36, 0x42, 0x35, 0x3f, 0x3b, + 0x3b, 0x43, 0x3c, 0x40, 0x40, 0x49, 0x2e, 0x39, 0x40, 0x3f, 0x45, 0x41, + 0x3f, 0x30, 0x42, 0x3d, 0x40, 0x3c, 0x3a, 0x3b, 0x3b, 0x40, 0x39, 0x42, + 0x3a, 0x3f, 0x3f, 0x3e, 0x35, 0x3b, 0x38, 0x45, 0x47, 0x35, 0x44, 0x3e, + 0x3b, 0x3f, 0x3f, 0x40, 0x3a, 0x35, 0x30, 0x49, 0x45, 0x35, 0x3b, 0x39, + 0x3b, 0x48, 0x3f, 0x37, 0x39, 0x40, 0x43, 0x45, 0x3d, 0x40, 0x41, 0x3a, + 0x33, 0x3d, 0x3a, 0x4b, 0x40, 0x42, 0x40, 0x42, 0x43, 0x39, 0x3c, 0x49, + 0x3e, 0x47, 0x3e, 0x44, 0x3f, 0x3a, 0x40, 0x41, 0x3f, 0x42, 0x42, 0x37, + 0x3e, 0x3b, 0x36, 0x3e, 0x3b, 0x3c, 0x48, 0x43, 0x2d, 0x46, 0x4a, 0x38, + 0x45, 0x3a, 0x29, 0x46, 0x40, 0x3c, 0x40, 0x44, 0x40, 0x33, 0x2f, 0x33, + 0x48, 0x2e, 0x51, 0x4f, 0x2b, 0x32, 0x2e, 0x2d, 0x45, 0x33, 0x4d, 0x41, + 0x29, 0x4b, 0x41, 0x39, 0x2f, 0x3a, 0x49, 0x31, 0x37, 0x40, 0x47, 0x4c, + 0x3e, 0x31, 0x41, 0x3f, 0x43, 0x37, 0x45, 0x4f, 0x41, 0x3c, 0x30, 0x4a, + 0x37, 0x37, 0x36, 0x39, 0x31, 0x3d, 0x36, 0x4b, 0x37, 0x44, 0x3c, 0x43, + 0x44, 0x36, 0x3f, 0x3b, 0x34, 0x3e, 0x3a, 0x35, 0x38, 0x3f, 0x33, 0x37, + 0x3b, 0x3d, 0x46, 0x38, 0x3b, 0x37, 0x38, 0x3b, 0x31, 0x3e, 0x3d, 0x3b, + 0x3d, 
0x39, 0x35, 0x33, 0x33, 0x3c, 0x39, 0x39, 0x48, 0x39, 0x39, 0x3f, + 0x3e, 0x36, 0x47, 0x3a, 0x44, 0x40, 0x32, 0x3c, 0x37, 0x35, 0x40, 0x3f, + 0x3a, 0x38, 0x3b, 0x3d, 0x46, 0x45, 0x36, 0x43, 0x40, 0x3d, 0x41, 0x41, + 0x47, 0x3a, 0x3d, 0x3e, 0x43, 0x42, 0x32, 0x36, 0x41, 0x37, 0x3b, 0x35, + 0x36, 0x44, 0x36, 0x3c, 0x43, 0x32, 0x3e, 0x3e, 0x42, 0x45, 0x32, 0x3c, + 0x3a, 0x3b, 0x35, 0x43, 0x41, 0x3d, 0x44, 0x50, 0x43, 0x31, 0x3e, 0x44, + 0x44, 0x41, 0x3a, 0x44, 0x36, 0x39, 0x3b, 0x3c, 0x32, 0x38, 0x3b, 0x45, + 0x38, 0x43, 0x40, 0x42, 0x33, 0x3e, 0x4a, 0x42, 0x45, 0x39, 0x2f, 0x42, + 0x39, 0x35, 0x44, 0x3e, 0x39, 0x2f, 0x34, 0x33, 0x49, 0x29, 0x50, 0x4f, + 0x2b, 0x36, 0x34, 0x2d, 0x47, 0x33, 0x49, 0x3c, 0x33, 0x51, 0x49, 0x3f, + 0x34, 0x39, 0x4a, 0x2c, 0x34, 0x45, 0x4f, 0x47, 0x34, 0x42, 0x3a, 0x3d, + 0x36, 0x4a, 0x3b, 0x43, 0x36, 0x3f, 0x39, 0x4b, 0x38, 0x3a, 0x31, 0x3d, + 0x32, 0x42, 0x3a, 0x47, 0x48, 0x3e, 0x44, 0x3f, 0x39, 0x3e, 0x44, 0x35, + 0x41, 0x3c, 0x45, 0x3a, 0x3e, 0x3b, 0x3d, 0x2f, 0x37, 0x40, 0x3e, 0x43, + 0x39, 0x39, 0x33, 0x3b, 0x37, 0x3b, 0x37, 0x37, 0x37, 0x39, 0x36, 0x31, + 0x39, 0x3b, 0x41, 0x39, 0x3b, 0x40, 0x36, 0x37, 0x42, 0x39, 0x3a, 0x46, + 0x3f, 0x30, 0x38, 0x39, 0x35, 0x32, 0x3e, 0x3a, 0x43, 0x43, 0x3e, 0x33, + 0x42, 0x3f, 0x41, 0x3c, 0x46, 0x34, 0x34, 0x40, 0x43, 0x37, 0x32, 0x43, + 0x3c, 0x37, 0x36, 0x33, 0x3d, 0x36, 0x3a, 0x40, 0x39, 0x38, 0x32, 0x3e, + 0x32, 0x3d, 0x37, 0x49, 0x42, 0x47, 0x41, 0x3b, 0x3d, 0x3c, 0x3a, 0x37, + 0x3c, 0x45, 0x3a, 0x45, 0x36, 0x44, 0x3a, 0x3a, 0x3a, 0x3c, 0x43, 0x3b, + 0x3b, 0x35, 0x38, 0x47, 0x36, 0x40, 0x32, 0x43, 0x3e, 0x39, 0x42, 0x40, + 0x2c, 0x3c, 0x4c, 0x4c, 0x43, 0x3b, 0x37, 0x4a, 0x3f, 0x3c, 0x45, 0x44, + 0x3f, 0x30, 0x36, 0x31, 0x4f, 0x2f, 0x5d, 0x4b, 0x34, 0x34, 0x2d, 0x2b, + 0x44, 0x31, 0x4e, 0x40, 0x2e, 0x4d, 0x48, 0x3e, 0x37, 0x2b, 0x49, 0x25, + 0x31, 0x49, 0x44, 0x49, 0x39, 0x39, 0x4b, 0x3a, 0x3a, 0x41, 0x3e, 0x42, + 0x3c, 0x36, 0x36, 0x4a, 0x32, 0x44, 0x3e, 0x48, 0x3e, 0x3c, 0x37, 0x49, + 0x3d, 
0x34, 0x3f, 0x37, 0x33, 0x36, 0x46, 0x3a, 0x3a, 0x31, 0x45, 0x3f, + 0x3a, 0x31, 0x3b, 0x33, 0x41, 0x42, 0x35, 0x39, 0x38, 0x44, 0x36, 0x3a, + 0x3f, 0x3b, 0x37, 0x3e, 0x3b, 0x38, 0x2f, 0x32, 0x44, 0x3d, 0x44, 0x41, + 0x39, 0x36, 0x3a, 0x34, 0x39, 0x38, 0x34, 0x3f, 0x3b, 0x37, 0x34, 0x34, + 0x40, 0x3d, 0x34, 0x3a, 0x46, 0x42, 0x3f, 0x34, 0x38, 0x33, 0x39, 0x44, + 0x3f, 0x41, 0x3c, 0x31, 0x40, 0x32, 0x3f, 0x37, 0x37, 0x41, 0x3e, 0x35, + 0x37, 0x48, 0x3b, 0x41, 0x3d, 0x3a, 0x3f, 0x35, 0x33, 0x3c, 0x36, 0x3b, + 0x3a, 0x48, 0x33, 0x42, 0x37, 0x33, 0x39, 0x41, 0x3c, 0x3d, 0x3b, 0x4d, + 0x39, 0x3a, 0x3e, 0x44, 0x3d, 0x41, 0x3b, 0x38, 0x49, 0x41, 0x3a, 0x38, + 0x34, 0x38, 0x38, 0x3c, 0x45, 0x3c, 0x37, 0x3b, 0x36, 0x3e, 0x4a, 0x4b, + 0x42, 0x3f, 0x32, 0x45, 0x46, 0x35, 0x46, 0x41, 0x38, 0x33, 0x39, 0x37, + 0x44, 0x2b, 0x60, 0x4a, 0x2a, 0x2e, 0x35, 0x2d, 0x43, 0x37, 0x51, 0x47, + 0x2f, 0x4d, 0x50, 0x3e, 0x3a, 0x33, 0x4f, 0x2a, 0x35, 0x45, 0x4a, 0x4c, + 0x3b, 0x3d, 0x43, 0x44, 0x3d, 0x3f, 0x4a, 0x3e, 0x49, 0x37, 0x2e, 0x4f, + 0x39, 0x3f, 0x32, 0x3c, 0x37, 0x3b, 0x39, 0x4d, 0x34, 0x3f, 0x46, 0x44, + 0x3d, 0x40, 0x3f, 0x40, 0x39, 0x33, 0x39, 0x3e, 0x3d, 0x40, 0x31, 0x30, + 0x35, 0x3d, 0x3e, 0x3a, 0x3e, 0x32, 0x31, 0x3e, 0x48, 0x3c, 0x40, 0x43, + 0x3f, 0x3f, 0x34, 0x2e, 0x3a, 0x3e, 0x3b, 0x43, 0x45, 0x32, 0x3a, 0x31, + 0x37, 0x38, 0x31, 0x35, 0x34, 0x3d, 0x42, 0x36, 0x46, 0x37, 0x32, 0x47, + 0x41, 0x3c, 0x35, 0x35, 0x36, 0x41, 0x3a, 0x3b, 0x42, 0x44, 0x36, 0x31, + 0x3c, 0x3d, 0x34, 0x34, 0x3b, 0x40, 0x40, 0x2e, 0x40, 0x46, 0x3b, 0x43, + 0x3f, 0x40, 0x3b, 0x3a, 0x32, 0x40, 0x46, 0x39, 0x3c, 0x49, 0x2f, 0x3d, + 0x49, 0x3e, 0x44, 0x3c, 0x3e, 0x35, 0x3f, 0x44, 0x41, 0x40, 0x3e, 0x47, + 0x3d, 0x40, 0x3f, 0x41, 0x3b, 0x41, 0x41, 0x3f, 0x40, 0x3f, 0x3e, 0x3e, + 0x3f, 0x43, 0x35, 0x40, 0x2b, 0x42, 0x45, 0x56, 0x40, 0x3c, 0x2f, 0x44, + 0x44, 0x3d, 0x3e, 0x3d, 0x40, 0x2d, 0x39, 0x31, 0x54, 0x2f, 0x61, 0x48, + 0x2e, 0x37, 0x37, 0x32, 0x3e, 0x2d, 0x52, 0x4d, 0x2d, 0x4d, 0x4c, 0x3a, + 0x3a, 
0x31, 0x4e, 0x2d, 0x31, 0x48, 0x47, 0x54, 0x45, 0x38, 0x3b, 0x3d, + 0x42, 0x41, 0x44, 0x4a, 0x48, 0x42, 0x2f, 0x4d, 0x31, 0x34, 0x3a, 0x46, + 0x37, 0x44, 0x2c, 0x45, 0x46, 0x43, 0x40, 0x3f, 0x34, 0x33, 0x40, 0x39, + 0x32, 0x35, 0x3a, 0x40, 0x3f, 0x3f, 0x36, 0x32, 0x3f, 0x3d, 0x35, 0x48, + 0x3c, 0x48, 0x37, 0x39, 0x35, 0x3f, 0x37, 0x3d, 0x44, 0x46, 0x2d, 0x2a, + 0x47, 0x38, 0x3a, 0x39, 0x45, 0x3b, 0x40, 0x2d, 0x37, 0x33, 0x41, 0x3c, + 0x40, 0x35, 0x3f, 0x32, 0x3a, 0x36, 0x40, 0x41, 0x3a, 0x3c, 0x33, 0x31, + 0x42, 0x3f, 0x41, 0x3a, 0x41, 0x46, 0x38, 0x2f, 0x3c, 0x3d, 0x3d, 0x39, + 0x3b, 0x46, 0x41, 0x31, 0x46, 0x36, 0x40, 0x48, 0x3c, 0x33, 0x42, 0x32, + 0x3b, 0x40, 0x3f, 0x36, 0x37, 0x44, 0x34, 0x35, 0x32, 0x32, 0x37, 0x38, + 0x33, 0x3b, 0x37, 0x4a, 0x3f, 0x46, 0x3a, 0x41, 0x32, 0x37, 0x30, 0x3e, + 0x40, 0x35, 0x41, 0x40, 0x37, 0x41, 0x2b, 0x40, 0x3d, 0x3d, 0x32, 0x38, + 0x34, 0x3e, 0x47, 0x61, 0x43, 0x3b, 0x3c, 0x42, 0x46, 0x3d, 0x40, 0x4a, + 0x3c, 0x2d, 0x33, 0x35, 0x55, 0x38, 0x69, 0x4f, 0x33, 0x37, 0x30, 0x39, + 0x44, 0x2e, 0x58, 0x4b, 0x2a, 0x51, 0x4b, 0x3c, 0x39, 0x2e, 0x51, 0x2d, + 0x30, 0x4a, 0x42, 0x53, 0x3f, 0x39, 0x3e, 0x44, 0x3b, 0x40, 0x47, 0x44, + 0x47, 0x3e, 0x39, 0x4b, 0x40, 0x3d, 0x42, 0x39, 0x3b, 0x39, 0x32, 0x42, + 0x36, 0x36, 0x36, 0x42, 0x44, 0x34, 0x33, 0x40, 0x40, 0x40, 0x3a, 0x3a, + 0x41, 0x3f, 0x31, 0x30, 0x3f, 0x31, 0x30, 0x39, 0x46, 0x36, 0x35, 0x34, + 0x40, 0x43, 0x3c, 0x41, 0x31, 0x46, 0x35, 0x26, 0x44, 0x32, 0x3d, 0x35, + 0x3d, 0x3c, 0x36, 0x32, 0x39, 0x3a, 0x30, 0x40, 0x48, 0x3e, 0x38, 0x37, + 0x44, 0x3b, 0x3d, 0x42, 0x3d, 0x3c, 0x32, 0x2b, 0x3f, 0x41, 0x39, 0x3d, + 0x3e, 0x3f, 0x35, 0x2f, 0x46, 0x3d, 0x3d, 0x3b, 0x45, 0x37, 0x31, 0x35, + 0x44, 0x40, 0x3a, 0x45, 0x3a, 0x3c, 0x39, 0x31, 0x3b, 0x3d, 0x3b, 0x3a, + 0x43, 0x44, 0x39, 0x47, 0x44, 0x36, 0x3e, 0x39, 0x48, 0x3f, 0x39, 0x4b, + 0x3c, 0x36, 0x3d, 0x44, 0x44, 0x3f, 0x39, 0x43, 0x3f, 0x37, 0x3f, 0x37, + 0x3b, 0x3b, 0x38, 0x3b, 0x3f, 0x40, 0x31, 0x44, 0x30, 0x44, 0x46, 0x5b, + 0x46, 
0x3f, 0x39, 0x40, 0x40, 0x37, 0x4a, 0x46, 0x3f, 0x36, 0x40, 0x39, + 0x59, 0x3e, 0x66, 0x57, 0x32, 0x34, 0x2e, 0x33, 0x46, 0x31, 0x58, 0x44, + 0x26, 0x4c, 0x4b, 0x3c, 0x39, 0x2e, 0x4d, 0x35, 0x32, 0x46, 0x52, 0x52, + 0x3e, 0x40, 0x39, 0x3c, 0x39, 0x3d, 0x53, 0x48, 0x41, 0x3c, 0x3b, 0x4d, + 0x3c, 0x3e, 0x38, 0x44, 0x3a, 0x3a, 0x29, 0x4a, 0x3c, 0x37, 0x36, 0x38, + 0x3a, 0x31, 0x37, 0x39, 0x3a, 0x40, 0x46, 0x32, 0x42, 0x38, 0x32, 0x2e, + 0x3a, 0x45, 0x44, 0x34, 0x34, 0x38, 0x32, 0x2e, 0x35, 0x40, 0x3a, 0x41, + 0x42, 0x3d, 0x37, 0x2c, 0x3f, 0x37, 0x3c, 0x3d, 0x3a, 0x36, 0x33, 0x35, + 0x3c, 0x34, 0x3c, 0x39, 0x3c, 0x3a, 0x37, 0x30, 0x30, 0x3e, 0x3d, 0x3a, + 0x44, 0x37, 0x36, 0x32, 0x36, 0x37, 0x36, 0x3a, 0x3c, 0x41, 0x3a, 0x35, + 0x36, 0x3a, 0x34, 0x40, 0x39, 0x40, 0x3e, 0x32, 0x34, 0x46, 0x33, 0x3f, + 0x36, 0x45, 0x3e, 0x35, 0x3f, 0x38, 0x3f, 0x3e, 0x3b, 0x3a, 0x36, 0x3b, + 0x36, 0x38, 0x32, 0x3f, 0x44, 0x3c, 0x35, 0x48, 0x38, 0x39, 0x31, 0x49, + 0x3d, 0x43, 0x36, 0x3f, 0x31, 0x43, 0x36, 0x3e, 0x3e, 0x41, 0x39, 0x3b, + 0x40, 0x42, 0x3c, 0x43, 0x36, 0x4a, 0x48, 0x67, 0x4e, 0x43, 0x36, 0x46, + 0x44, 0x3f, 0x4b, 0x4b, 0x3f, 0x38, 0x3c, 0x3c, 0x5e, 0x38, 0x70, 0x52, + 0x38, 0x32, 0x3b, 0x36, 0x4a, 0x2c, 0x52, 0x46, 0x29, 0x4f, 0x48, 0x42, + 0x2d, 0x2e, 0x4f, 0x28, 0x28, 0x45, 0x4d, 0x52, 0x42, 0x3e, 0x3f, 0x41, + 0x3c, 0x3a, 0x47, 0x50, 0x44, 0x45, 0x33, 0x4b, 0x3e, 0x3f, 0x42, 0x3d, + 0x43, 0x34, 0x27, 0x3f, 0x42, 0x3e, 0x43, 0x3e, 0x3a, 0x3c, 0x37, 0x3b, + 0x3f, 0x30, 0x3a, 0x3e, 0x3c, 0x34, 0x37, 0x24, 0x3d, 0x43, 0x40, 0x44, + 0x40, 0x46, 0x31, 0x2f, 0x43, 0x38, 0x38, 0x39, 0x3c, 0x34, 0x2d, 0x2a, + 0x38, 0x31, 0x43, 0x3b, 0x39, 0x3b, 0x32, 0x34, 0x3e, 0x39, 0x41, 0x3b, + 0x3e, 0x33, 0x3a, 0x2a, 0x41, 0x3f, 0x3c, 0x43, 0x3b, 0x3e, 0x35, 0x2c, + 0x38, 0x41, 0x33, 0x31, 0x3e, 0x3f, 0x3a, 0x3c, 0x3b, 0x35, 0x3f, 0x3d, + 0x42, 0x3a, 0x3c, 0x35, 0x3f, 0x40, 0x3c, 0x3e, 0x37, 0x41, 0x3d, 0x38, + 0x34, 0x31, 0x36, 0x3d, 0x3d, 0x47, 0x36, 0x44, 0x3f, 0x45, 0x3c, 0x3c, + 0x35, 
0x36, 0x31, 0x4f, 0x46, 0x3a, 0x41, 0x42, 0x40, 0x32, 0x33, 0x41, + 0x34, 0x40, 0x3d, 0x43, 0x3b, 0x3a, 0x32, 0x3c, 0x42, 0x42, 0x3d, 0x43, + 0x37, 0x45, 0x45, 0xff, 0x4b, 0x45, 0x3b, 0x40, 0x43, 0x3e, 0x47, 0x49, + 0x3d, 0x3b, 0x3e, 0x33, 0x58, 0x35, 0x71, 0x54, 0x2f, 0x38, 0x38, 0x33, + 0x47, 0x35, 0x5b, 0x46, 0x2c, 0x4c, 0x43, 0x37, 0x36, 0x39, 0x4f, 0x30, + 0x26, 0x48, 0x51, 0x48, 0x46, 0x45, 0x3b, 0x39, 0x42, 0x50, 0x47, 0x4c, + 0x4b, 0x3b, 0x3d, 0x4d, 0x41, 0x34, 0x40, 0x44, 0x38, 0x32, 0x2d, 0x43, + 0x39, 0x36, 0x3b, 0x3b, 0x40, 0x3d, 0x37, 0x3c, 0x44, 0x39, 0x42, 0x37, + 0x38, 0x38, 0x32, 0x2f, 0x41, 0x40, 0x3f, 0x3a, 0x37, 0x35, 0x3b, 0x2a, + 0x37, 0x30, 0x3c, 0x37, 0x40, 0x38, 0x3a, 0x27, 0x44, 0x3d, 0x43, 0x40, + 0x35, 0x3f, 0x3e, 0x32, 0x3e, 0x3c, 0x40, 0x39, 0x39, 0x3a, 0x41, 0x31, + 0x3b, 0x3f, 0x34, 0x43, 0x3a, 0x38, 0x42, 0x2a, 0x47, 0x46, 0x3b, 0x38, + 0x47, 0x45, 0x39, 0x31, 0x43, 0x40, 0x37, 0x3a, 0x3d, 0x3e, 0x39, 0x30, + 0x36, 0x37, 0x3a, 0x43, 0x3f, 0x32, 0x31, 0x41, 0x45, 0x3e, 0x43, 0x38, + 0x3f, 0x37, 0x3c, 0x49, 0x3b, 0x33, 0x3d, 0x3a, 0x37, 0x44, 0x32, 0x50, + 0x39, 0x44, 0x3e, 0x3f, 0x3d, 0x41, 0x3e, 0x3e, 0x42, 0x44, 0x45, 0x3f, + 0x36, 0x3f, 0x37, 0x39, 0x3b, 0x3d, 0x3b, 0x3b, 0x2f, 0x46, 0x40, 0x6d, + 0x50, 0x45, 0x3b, 0x45, 0x46, 0x3b, 0x42, 0x48, 0x42, 0x3c, 0x39, 0x37, + 0x57, 0x3b, 0x6c, 0x5b, 0x32, 0x35, 0x3d, 0x39, 0x48, 0x31, 0x5c, 0x46, + 0x29, 0x4c, 0x3f, 0x3e, 0x37, 0x33, 0x58, 0x32, 0x2a, 0x43, 0x4c, 0x50, + 0x3b, 0x44, 0x3c, 0x41, 0x39, 0x48, 0x55, 0x4c, 0x42, 0x38, 0x3b, 0x51, + 0x3f, 0x38, 0x44, 0x46, 0x36, 0x3b, 0x38, 0x4a, 0x3f, 0x37, 0x36, 0x3c, + 0x31, 0x3d, 0x32, 0x39, 0x3b, 0x3f, 0x3e, 0x35, 0x38, 0x3f, 0x34, 0x2b, + 0x37, 0x36, 0x39, 0x40, 0x37, 0x41, 0x32, 0x27, 0x36, 0x33, 0x40, 0x3a, + 0x3f, 0x44, 0x3f, 0x25, 0x38, 0x34, 0x42, 0x3c, 0x3a, 0x40, 0x38, 0x31, + 0x49, 0x3e, 0x33, 0x3d, 0x31, 0x36, 0x39, 0x2b, 0x44, 0x2f, 0x43, 0x34, + 0x34, 0x37, 0x39, 0x33, 0x3b, 0x34, 0x42, 0x3c, 0x40, 0x45, 0x36, 0x31, + 0x43, 
0x47, 0x3e, 0x3f, 0x40, 0x3a, 0x33, 0x34, 0x41, 0x44, 0x3a, 0x43, + 0x3e, 0x38, 0x36, 0x31, 0x42, 0x44, 0x40, 0x41, 0x44, 0x43, 0x33, 0x42, + 0x3d, 0x41, 0x3d, 0x3e, 0x3c, 0x39, 0x3e, 0x4f, 0x3f, 0x37, 0x31, 0x40, + 0x3b, 0x38, 0x35, 0x3b, 0x44, 0x41, 0x41, 0x37, 0x40, 0x42, 0x2d, 0x3d, + 0x39, 0x48, 0x44, 0x3e, 0x34, 0x48, 0x49, 0x6d, 0x45, 0x4b, 0x3a, 0x44, + 0x49, 0x40, 0x4d, 0x51, 0x3f, 0x34, 0x3b, 0x40, 0x52, 0x34, 0x6f, 0x56, + 0x33, 0x3e, 0x40, 0x39, 0x41, 0x32, 0x5d, 0x45, 0x2e, 0x51, 0x48, 0x3c, + 0x2e, 0x2e, 0x51, 0x39, 0x32, 0x45, 0x4a, 0x4c, 0x3b, 0x40, 0x40, 0x3b, + 0x36, 0x41, 0x54, 0x4e, 0x4a, 0x49, 0x3b, 0x4d, 0x3c, 0x41, 0x38, 0x47, + 0x3d, 0x3c, 0x37, 0x48, 0x3f, 0x42, 0x3e, 0x36, 0x39, 0x46, 0x37, 0x3e, + 0x3b, 0x38, 0x40, 0x3b, 0x39, 0x32, 0x3e, 0x29, 0x37, 0x35, 0x3c, 0x3d, + 0x37, 0x3b, 0x35, 0x2f, 0x32, 0x3b, 0x37, 0x3c, 0x40, 0x3e, 0x39, 0x27, + 0x3b, 0x38, 0x37, 0x36, 0x39, 0x37, 0x37, 0x35, 0x42, 0x3e, 0x3b, 0x43, + 0x41, 0x3c, 0x37, 0x2a, 0x3a, 0x3e, 0x38, 0x40, 0x36, 0x3e, 0x44, 0x2e, + 0x3e, 0x3a, 0x37, 0x3b, 0x3e, 0x41, 0x3d, 0x30, 0x3b, 0x3f, 0x41, 0x45, + 0x3a, 0x48, 0x37, 0x2f, 0x3a, 0x37, 0x34, 0x43, 0x42, 0x3d, 0x38, 0x41, + 0x3b, 0x3c, 0x39, 0x3c, 0x39, 0x47, 0x2e, 0x41, 0x42, 0x40, 0x32, 0x36, + 0x43, 0x40, 0x3d, 0x4c, 0x38, 0x3e, 0x3b, 0x41, 0x3d, 0x3b, 0x34, 0x43, + 0x43, 0x3f, 0x44, 0x3c, 0x3a, 0x33, 0x39, 0x42, 0x43, 0x3f, 0x33, 0x3d, + 0x33, 0x3e, 0x48, 0x6b, 0x48, 0x43, 0x36, 0x47, 0x49, 0x44, 0x4a, 0x49, + 0x3c, 0x31, 0x35, 0x3e, 0x5c, 0x34, 0x73, 0x53, 0x33, 0x3c, 0x32, 0x3b, + 0x43, 0x27, 0x59, 0x4e, 0x2b, 0x51, 0x4f, 0x37, 0x36, 0x34, 0x56, 0x34, + 0x32, 0x4f, 0x46, 0x50, 0x40, 0x40, 0x3c, 0x3e, 0x34, 0x37, 0x50, 0x49, + 0x43, 0x47, 0x3e, 0x52, 0x44, 0x38, 0x3b, 0x4f, 0x3a, 0x3d, 0x2b, 0x4c, + 0x40, 0x38, 0x3a, 0x35, 0x3a, 0x3a, 0x3d, 0x38, 0x3d, 0x3b, 0x37, 0x48, + 0x3d, 0x3d, 0x32, 0x30, 0x3a, 0x34, 0x3f, 0x3a, 0x3b, 0x3e, 0x35, 0x2f, + 0x3b, 0x3a, 0x45, 0x3d, 0x42, 0x33, 0x33, 0x24, 0x44, 0x39, 0x3c, 0x3d, + 0x41, 
0x3c, 0x37, 0x2c, 0x3b, 0x36, 0x34, 0x41, 0x3d, 0x3f, 0x39, 0x32, + 0x3c, 0x40, 0x44, 0x3d, 0x41, 0x3d, 0x3a, 0x29, 0x3e, 0x3e, 0x43, 0x33, + 0x3f, 0x3e, 0x3e, 0x31, 0x38, 0x3a, 0x34, 0x3d, 0x3f, 0x3e, 0x3a, 0x3d, + 0x3e, 0x48, 0x45, 0x3d, 0x44, 0x37, 0x33, 0x3d, 0x45, 0x39, 0x40, 0x40, + 0x42, 0x3f, 0x3f, 0x3d, 0x3a, 0x3b, 0x41, 0x33, 0x41, 0x3c, 0x32, 0x55, + 0x43, 0x3a, 0x32, 0x40, 0x3c, 0x3e, 0x40, 0x43, 0x37, 0x3f, 0x40, 0x38, + 0x43, 0x41, 0x36, 0x42, 0x44, 0x3c, 0x32, 0x3f, 0x38, 0x42, 0x46, 0x59, + 0x4c, 0x41, 0x39, 0x47, 0x46, 0x46, 0x44, 0x44, 0x35, 0x42, 0x32, 0x39, + 0x4f, 0x34, 0x6d, 0x55, 0x31, 0x3b, 0x3a, 0x3f, 0x44, 0x2c, 0x5d, 0x43, + 0x26, 0x4a, 0x4f, 0x40, 0x36, 0x32, 0x4d, 0x33, 0x2f, 0x50, 0x4d, 0x57, + 0x3b, 0x40, 0x42, 0x44, 0x41, 0x3f, 0x52, 0x4e, 0x35, 0x41, 0x44, 0x52, + 0x40, 0x35, 0x39, 0x4b, 0x45, 0x34, 0x2c, 0x4a, 0x3b, 0x41, 0x31, 0x33, + 0x3f, 0x3a, 0x36, 0x3c, 0x3c, 0x33, 0x30, 0x38, 0x43, 0x3f, 0x32, 0x2d, + 0x3f, 0x3a, 0x38, 0x41, 0x39, 0x45, 0x36, 0x2e, 0x3c, 0x38, 0x45, 0x3f, + 0x40, 0x3f, 0x3e, 0x26, 0x41, 0x37, 0x3c, 0x44, 0x3f, 0x3f, 0x35, 0x37, + 0x46, 0x34, 0x37, 0x3e, 0x48, 0x38, 0x36, 0x34, 0x33, 0x39, 0x40, 0x3c, + 0x42, 0x3d, 0x3b, 0x31, 0x38, 0x3b, 0x44, 0x42, 0x45, 0x38, 0x41, 0x30, + 0x3d, 0x42, 0x36, 0x3f, 0x3b, 0x45, 0x37, 0x32, 0x3c, 0x37, 0x3d, 0x42, + 0x38, 0x3d, 0x2f, 0x31, 0x39, 0x40, 0x3f, 0x44, 0x3a, 0x41, 0x44, 0x46, + 0x3d, 0x3a, 0x32, 0x3b, 0x34, 0x47, 0x36, 0x4c, 0x47, 0x35, 0x3c, 0x33, + 0x3b, 0x3c, 0x30, 0x43, 0x43, 0x3f, 0x31, 0x40, 0x3a, 0x37, 0x30, 0x46, + 0x39, 0x3b, 0x42, 0x40, 0x2d, 0x3f, 0x3e, 0x6a, 0x50, 0x3b, 0x31, 0x54, + 0x47, 0x3d, 0x48, 0x4e, 0x3b, 0x41, 0x3a, 0x39, 0x49, 0x36, 0x64, 0x4e, + 0x32, 0x39, 0x3d, 0x37, 0x42, 0x2c, 0x5c, 0x43, 0x2a, 0x4b, 0x4b, 0x46, + 0x30, 0x29, 0x52, 0x31, 0x35, 0x44, 0x4a, 0x4b, 0x3d, 0x3b, 0x4e, 0x42, + 0x3d, 0x39, 0x42, 0x52, 0x3f, 0x36, 0x3e, 0x50, 0x3f, 0x32, 0x35, 0x3a, + 0x40, 0x39, 0x35, 0x48, 0x3b, 0x3e, 0x41, 0x43, 0x43, 0x45, 0x2f, 0x36, + 0x38, 
0x34, 0x3f, 0x44, 0x32, 0x3f, 0x37, 0x33, 0x33, 0x35, 0x2e, 0x41, + 0x37, 0x3e, 0x38, 0x28, 0x49, 0x30, 0x46, 0x39, 0x3b, 0x30, 0x38, 0x28, + 0x3b, 0x3d, 0x3a, 0x43, 0x3f, 0x34, 0x43, 0x36, 0x39, 0x3c, 0x3e, 0x3e, + 0x39, 0x3b, 0x39, 0x32, 0x3c, 0x36, 0x3e, 0x38, 0x34, 0x3c, 0x3a, 0x2a, + 0x46, 0x3d, 0x40, 0x37, 0x3b, 0x39, 0x3b, 0x34, 0x38, 0x31, 0x43, 0x46, + 0x3b, 0x43, 0x39, 0x2b, 0x38, 0x40, 0x3e, 0x39, 0x35, 0x3d, 0x2c, 0x36, + 0x37, 0x40, 0x36, 0x40, 0x41, 0x38, 0x32, 0x3f, 0x36, 0x46, 0x34, 0x31, + 0x40, 0x3e, 0x3c, 0x4e, 0x42, 0x3d, 0x36, 0x3f, 0x42, 0x3f, 0x33, 0x40, + 0x34, 0x37, 0x3c, 0x3b, 0x31, 0x47, 0x32, 0x3c, 0x34, 0x3d, 0x42, 0x3b, + 0x37, 0x41, 0x3b, 0x64, 0x52, 0x40, 0x36, 0x4e, 0x46, 0x3f, 0x3f, 0x47, + 0x3c, 0x3a, 0x3a, 0x41, 0x4a, 0x32, 0x5e, 0x50, 0x2d, 0x39, 0x3a, 0x38, + 0x3d, 0x2c, 0x5a, 0x3e, 0x2e, 0x47, 0x3e, 0x3e, 0x33, 0x29, 0x4c, 0x35, + 0x30, 0x4d, 0x4d, 0x4d, 0x38, 0x42, 0x51, 0x47, 0x39, 0x3c, 0x43, 0x4b, + 0x42, 0x3f, 0x3a, 0x4b, 0x44, 0x3f, 0x3a, 0x44, 0x3e, 0x37, 0x30, 0x45, + 0x3d, 0x36, 0x34, 0x3f, 0x36, 0x35, 0x37, 0x36, 0x43, 0x3b, 0x37, 0x3e, + 0x35, 0x3e, 0x32, 0x34, 0x32, 0x38, 0x3c, 0x3a, 0x3a, 0x3c, 0x30, 0x2b, + 0x31, 0x37, 0x30, 0x42, 0x36, 0x37, 0x36, 0x2c, 0x3c, 0x31, 0x41, 0x37, + 0x44, 0x41, 0x3b, 0x37, 0x41, 0x3f, 0x38, 0x3b, 0x3a, 0x3a, 0x3c, 0x2f, + 0x47, 0x41, 0x3e, 0x33, 0x42, 0x3a, 0x32, 0x34, 0x44, 0x40, 0x43, 0x3d, + 0x34, 0x41, 0x38, 0x35, 0x35, 0x3b, 0x45, 0x38, 0x32, 0x37, 0x3c, 0x2e, + 0x39, 0x40, 0x30, 0x3e, 0x42, 0x35, 0x3d, 0x36, 0x3e, 0x3d, 0x39, 0x46, + 0x3f, 0x36, 0x37, 0x49, 0x41, 0x39, 0x3d, 0x3d, 0x33, 0x44, 0x42, 0x50, + 0x3d, 0x3c, 0x3e, 0x3f, 0x42, 0x42, 0x3b, 0x3d, 0x41, 0x31, 0x39, 0x3a, + 0x44, 0x34, 0x38, 0x47, 0x44, 0x38, 0x3b, 0x42, 0x30, 0x42, 0x44, 0x57, + 0x49, 0x3a, 0x39, 0x4f, 0x41, 0x3e, 0x40, 0x43, 0x37, 0x42, 0x3b, 0x48, + 0x50, 0x29, 0x5b, 0x44, 0x2c, 0x40, 0x3f, 0x3c, 0x46, 0x34, 0x5c, 0x41, + 0x2c, 0x48, 0x46, 0x46, 0x35, 0x32, 0x4c, 0x35, 0x2f, 0x3b, 0x48, 0x44, + 0x41, 
0x41, 0x49, 0x45, 0x34, 0x37, 0x44, 0x45, 0x43, 0x3b, 0x42, 0x44, + 0x3a, 0x37, 0x48, 0x49, 0x34, 0x39, 0x33, 0x4a, 0x40, 0x3d, 0x33, 0x39, + 0x39, 0x3b, 0x30, 0x31, 0x3d, 0x47, 0x3c, 0x3a, 0x34, 0x3c, 0x3a, 0x2b, + 0x3a, 0x34, 0x41, 0x40, 0x42, 0x36, 0x44, 0x2c, 0x40, 0x47, 0x3b, 0x37, + 0x38, 0x42, 0x44, 0x29, 0x36, 0x3d, 0x3d, 0x36, 0x42, 0x3b, 0x35, 0x36, + 0x43, 0x39, 0x41, 0x3d, 0x45, 0x41, 0x31, 0x32, 0x40, 0x3d, 0x3c, 0x41, + 0x3e, 0x3d, 0x35, 0x34, 0x32, 0x38, 0x36, 0x3f, 0x3b, 0x3d, 0x39, 0x36, + 0x40, 0x3e, 0x3d, 0x3a, 0x3a, 0x3b, 0x3c, 0x32, 0x40, 0x34, 0x3a, 0x36, + 0x42, 0x47, 0x3e, 0x33, 0x3a, 0x44, 0x30, 0x39, 0x40, 0x3a, 0x36, 0x44, + 0x3c, 0x3b, 0x3f, 0x33, 0x3e, 0x3c, 0x35, 0x53, 0x43, 0x3c, 0x3f, 0x43, + 0x3d, 0x44, 0x33, 0x47, 0x42, 0x40, 0x37, 0x3b, 0x43, 0x3f, 0x33, 0x41, + 0x38, 0x42, 0x44, 0x3d, 0x2d, 0x3f, 0x46, 0x49, 0x4e, 0x3f, 0x36, 0x45, + 0x45, 0x39, 0x40, 0x42, 0x39, 0x39, 0x3a, 0x42, 0x45, 0x2c, 0x61, 0x44, + 0x30, 0x45, 0x38, 0x3a, 0x40, 0x37, 0x58, 0x39, 0x31, 0x3e, 0x3a, 0x3e, + 0x37, 0x32, 0x4a, 0x39, 0x2e, 0x47, 0x3e, 0x4e, 0x3f, 0x3e, 0x48, 0x45, + 0x3f, 0x48, 0x3a, 0x3f, 0x40, 0x36, 0x3a, 0x44, 0x36, 0x3e, 0x3d, 0x41, + 0x45, 0x36, 0x36, 0x4b, 0x3a, 0x3d, 0x45, 0x48, 0x38, 0x45, 0x39, 0x38, + 0x38, 0x3a, 0x42, 0x34, 0x3f, 0x34, 0x39, 0x34, 0x32, 0x3f, 0x3c, 0x3d, + 0x3d, 0x47, 0x3a, 0x2f, 0x3c, 0x3e, 0x3f, 0x39, 0x35, 0x42, 0x3c, 0x2a, + 0x3b, 0x35, 0x42, 0x44, 0x46, 0x39, 0x38, 0x39, 0x43, 0x3a, 0x38, 0x42, + 0x3d, 0x3a, 0x40, 0x35, 0x34, 0x39, 0x3a, 0x38, 0x43, 0x42, 0x42, 0x2d, + 0x31, 0x3b, 0x33, 0x40, 0x3b, 0x47, 0x35, 0x30, 0x3a, 0x3c, 0x3b, 0x47, + 0x3a, 0x3c, 0x38, 0x35, 0x3c, 0x35, 0x3e, 0x3e, 0x39, 0x3d, 0x39, 0x40, + 0x37, 0x33, 0x49, 0x38, 0x3c, 0x43, 0x34, 0x40, 0x39, 0x42, 0x3c, 0x3b, + 0x3e, 0x45, 0x3e, 0x51, 0x3d, 0x3f, 0x3b, 0x34, 0x37, 0x3c, 0x40, 0x47, + 0x3c, 0x41, 0x3f, 0x41, 0x37, 0x3e, 0x36, 0x3c, 0x42, 0x40, 0x3f, 0x3a, + 0x3b, 0x42, 0x44, 0x4b, 0x4b, 0x37, 0x41, 0x4d, 0x41, 0x45, 0x40, 0x41, + 0x40, 
0x38, 0x37, 0x40, 0x42, 0x2c, 0x57, 0x43, 0x2d, 0x49, 0x3a, 0x3e, + 0x37, 0x2f, 0x52, 0x37, 0x31, 0x42, 0x3b, 0x3f, 0x39, 0x38, 0x48, 0x3c, + 0x37, 0x3d, 0x3a, 0x39, 0x3a, 0x45, 0x4b, 0x49, 0x3e, 0x44, 0x48, 0x49, + 0x3d, 0x39, 0x3c, 0x41, 0x41, 0x38, 0x45, 0x38, 0x33, 0x3d, 0x37, 0x47, + 0x34, 0x3f, 0x3b, 0x3d, 0x39, 0x34, 0x30, 0x39, 0x44, 0x36, 0x34, 0x3c, + 0x37, 0x38, 0x45, 0x34, 0x40, 0x33, 0x41, 0x3a, 0x3e, 0x3c, 0x3b, 0x3a, + 0x40, 0x3f, 0x3b, 0x3d, 0x3b, 0x46, 0x41, 0x2a, 0x3a, 0x3c, 0x42, 0x46, + 0x33, 0x3f, 0x2d, 0x3a, 0x45, 0x45, 0x38, 0x3b, 0x44, 0x34, 0x35, 0x3f, + 0x34, 0x43, 0x38, 0x3e, 0x41, 0x3b, 0x42, 0x38, 0x3d, 0x3f, 0x38, 0x45, + 0x3b, 0x35, 0x39, 0x3c, 0x43, 0x43, 0x38, 0x34, 0x44, 0x43, 0x2e, 0x39, + 0x39, 0x40, 0x39, 0x41, 0x41, 0x34, 0x3e, 0x44, 0x3d, 0x43, 0x3a, 0x3a, + 0x3b, 0x3b, 0x36, 0x45, 0x3c, 0x43, 0x3d, 0x48, 0x36, 0x36, 0x39, 0x55, + 0x35, 0x40, 0x3e, 0x49, 0x40, 0x3a, 0x3d, 0x3d, 0x34, 0x47, 0x40, 0x41, + 0x40, 0x47, 0x39, 0x3e, 0x3b, 0x38, 0x3c, 0x3a, 0x35, 0x3e, 0x41, 0x4a, + 0x4b, 0x3f, 0x36, 0x3d, 0x40, 0x3c, 0x39, 0x32, 0x33, 0x36, 0x30, 0x42, + 0x42, 0x36, 0x54, 0x48, 0x2e, 0x4c, 0x34, 0x3c, 0x39, 0x36, 0x4e, 0x37, + 0x2f, 0x3e, 0x30, 0x3d, 0x36, 0x3b, 0x45, 0x36, 0x37, 0x3e, 0x41, 0x4b, + 0x3b, 0x36, 0x45, 0x3b, 0x38, 0x45, 0x3e, 0x43, 0x48, 0x46, 0x44, 0x44, + 0x3e, 0x3b, 0x37, 0x3b, 0x3a, 0x3f, 0x3d, 0x44, 0x39, 0x38, 0x45, 0x43, + 0x3d, 0x35, 0x39, 0x2c, 0x44, 0x41, 0x36, 0x40, 0x3d, 0x39, 0x3d, 0x2f, + 0x3d, 0x39, 0x42, 0x3d, 0x36, 0x46, 0x43, 0x2c, 0x41, 0x3a, 0x30, 0x45, + 0x3f, 0x41, 0x35, 0x2b, 0x3b, 0x38, 0x3a, 0x44, 0x32, 0x32, 0x39, 0x3c, + 0x3a, 0x3a, 0x3c, 0x3a, 0x35, 0x40, 0x3b, 0x31, 0x36, 0x33, 0x35, 0x34, + 0x3c, 0x3b, 0x3d, 0x36, 0x48, 0x3b, 0x3f, 0x42, 0x3e, 0x33, 0x2f, 0x3a, + 0x49, 0x41, 0x39, 0x3e, 0x3c, 0x44, 0x3c, 0x39, 0x33, 0x39, 0x36, 0x35, + 0x3d, 0x42, 0x34, 0x3e, 0x38, 0x45, 0x40, 0x45, 0x3d, 0x48, 0x42, 0x4a, + 0x3f, 0x45, 0x38, 0x42, 0x44, 0x40, 0x34, 0x49, 0x44, 0x3d, 0x3a, 0x39, + 0x3e, 
0x3a, 0x42, 0x3e, 0x48, 0x42, 0x3e, 0x3a, 0x3f, 0x3f, 0x32, 0x3b, + 0x38, 0x41, 0x3c, 0x39, 0x33, 0x45, 0x44, 0x3c, 0x48, 0x41, 0x41, 0x3d, + 0x3a, 0x3c, 0x37, 0x33, 0x41, 0x3f, 0x38, 0x3a, 0x3f, 0x37, 0x51, 0x3c, + 0x37, 0x3a, 0x43, 0x37, 0x40, 0x31, 0x4f, 0x34, 0x3b, 0x44, 0x45, 0x39, + 0x40, 0x33, 0x49, 0x33, 0x3e, 0x35, 0x44, 0x3d, 0x3b, 0x3f, 0x43, 0x41, + 0x43, 0x43, 0x48, 0x44, 0x46, 0x3b, 0x43, 0x3f, 0x3c, 0x3f, 0x3e, 0x3d, + 0x3b, 0x41, 0x3c, 0x43, 0x30, 0x34, 0x39, 0x33, 0x3f, 0x38, 0x36, 0x2e, + 0x33, 0x3f, 0x3c, 0x40, 0x3d, 0x3b, 0x3b, 0x31, 0x36, 0x41, 0x3b, 0x38, + 0x46, 0x36, 0x34, 0x31, 0x42, 0x44, 0x33, 0x35, 0x3f, 0x36, 0x3c, 0x30, + 0x3f, 0x31, 0x39, 0x3e, 0x3f, 0x47, 0x3e, 0x34, 0x36, 0x36, 0x34, 0x39, + 0x37, 0x46, 0x40, 0x33, 0x3b, 0x3a, 0x3f, 0x41, 0x37, 0x44, 0x3a, 0x3f, + 0x34, 0x45, 0x37, 0x33, 0x3f, 0x47, 0x41, 0x36, 0x39, 0x3e, 0x40, 0x38, + 0x41, 0x3d, 0x3d, 0x36, 0x40, 0x3a, 0x3b, 0x3b, 0x41, 0x3b, 0x3a, 0x3f, + 0x3f, 0x3b, 0x35, 0x42, 0x46, 0x3a, 0x30, 0x45, 0x40, 0x37, 0x39, 0x39, + 0x3d, 0x38, 0x3f, 0x45, 0x3f, 0x31, 0x32, 0x3b, 0x35, 0x3e, 0x3b, 0x38, + 0x3b, 0x44, 0x37, 0x39, 0x37, 0x42, 0x3f, 0x44, 0x38, 0x36, 0x37, 0x44, + 0x45, 0x46, 0x41, 0x3b, 0x46, 0x42, 0x43, 0x43, 0x3a, 0x4b, 0x37, 0x35, + 0x3b, 0x40, 0x32, 0x38, 0x41, 0x38, 0x4f, 0x3e, 0x36, 0x3f, 0x47, 0x3b, + 0x47, 0x3b, 0x4a, 0x2e, 0x3d, 0x45, 0x3b, 0x46, 0x3e, 0x38, 0x43, 0x38, + 0x41, 0x48, 0x3a, 0x39, 0x40, 0x45, 0x3b, 0x43, 0x40, 0x3e, 0x43, 0x41, + 0x41, 0x3e, 0x39, 0x3f, 0x35, 0x42, 0x33, 0x3f, 0x3d, 0x32, 0x45, 0x3c, + 0x41, 0x31, 0x45, 0x38, 0x43, 0x45, 0x41, 0x35, 0x35, 0x40, 0x44, 0x36, + 0x3a, 0x3b, 0x3c, 0x2c, 0x3e, 0x41, 0x33, 0x3d, 0x46, 0x34, 0x3b, 0x30, + 0x30, 0x42, 0x43, 0x3d, 0x3d, 0x3d, 0x43, 0x31, 0x3f, 0x40, 0x3a, 0x3f, + 0x48, 0x3e, 0x3b, 0x39, 0x44, 0x43, 0x3b, 0x3a, 0x42, 0x38, 0x38, 0x3b, + 0x3f, 0x44, 0x37, 0x3e, 0x45, 0x40, 0x41, 0x3b, 0x3c, 0x3a, 0x38, 0x37, + 0x3b, 0x33, 0x3f, 0x35, 0x43, 0x3d, 0x33, 0x41, 0x3b, 0x46, 0x39, 0x32, + 0x39, 
0x3f, 0x3b, 0x39, 0x47, 0x3c, 0x3f, 0x39, 0x34, 0x3d, 0x3c, 0x46, + 0x3f, 0x3e, 0x3e, 0x44, 0x34, 0x40, 0x3f, 0x39, 0x3c, 0x38, 0x36, 0x45, + 0x42, 0x46, 0x3b, 0x44, 0x3a, 0x3d, 0x3b, 0x42, 0x3b, 0x3b, 0x3c, 0x45, + 0x42, 0x3d, 0x36, 0x37, 0x3d, 0x43, 0x3f, 0x48, 0xa6, 0xfb, 0xff, 0xff, + 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0xb3, 0x00, 0x00, 0x00, + 0x39, 0xff, 0xff, 0xff, 0xe5, 0xff, 0xff, 0xff, 0x30, 0x00, 0x00, 0x00, + 0x68, 0xfb, 0xff, 0xff, 0xbc, 0xfc, 0xff, 0xff, 0x20, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0xe8, 0x03, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x70, 0x02, 0x00, 0x00, + 0x70, 0x03, 0x00, 0x00, 0xf0, 0x00, 0x00, 0x00, 0xf0, 0x01, 0x00, 0x00, + 0x80, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x50, 0x01, 0x00, 0x00, + 0xa4, 0x02, 0x00, 0x00, 0xba, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, + 0x24, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x6c, 0x61, 0x62, 0x65, + 0x6c, 0x73, 0x5f, 0x73, 0x6f, 0x66, 0x74, 0x6d, 0x61, 0x78, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x3c, 0xfd, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x80, 0x3b, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7f, 0x3f, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2a, 0xfd, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x03, 0x1c, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x61, 0x64, 0x64, 0x5f, 0x31, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa4, 0xfd, 0xff, 0xff, + 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x04, 
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x09, 0x97, 0xf5, 0x3f, + 0x01, 0x00, 0x00, 0x00, 0x87, 0x35, 0xa0, 0x43, 0x01, 0x00, 0x00, 0x00, + 0xd6, 0xd7, 0x28, 0xc3, 0x92, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, + 0x1c, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x52, 0x65, 0x6c, 0x75, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x19, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x14, 0xfe, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x05, 0x80, 0xbf, 0x3f, 0x01, 0x00, 0x00, 0x00, 0x85, 0xc0, 0xbe, 0x43, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xfe, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x03, 0x3c, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, + 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x73, 0x5f, 0x71, 0x75, 0x61, 0x6e, + 0x74, 0x2f, 0x46, 0x61, 0x6b, 0x65, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x57, + 0x69, 0x74, 0x68, 0x4d, 0x69, 0x6e, 0x4d, 0x61, 0x78, 0x56, 0x61, 0x72, + 0x73, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0xa4, 0xfe, 0xff, 0xff, 0x30, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0xae, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x17, 0xac, 0x6e, 0x3a, 0x01, 0x00, 0x00, 0x00, + 0x20, 0x4e, 0x97, 0x3d, 0x01, 0x00, 0x00, 0x00, 0xaf, 0x27, 0x21, 0xbe, + 0x96, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, 0x20, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, + 0x09, 
0x00, 0x00, 0x00, 0x52, 0x65, 0x73, 0x68, 0x61, 0x70, 0x65, 0x5f, + 0x31, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x31, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x1c, 0xff, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x3f, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x42, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x02, 0x20, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x43, 0x6f, 0x6e, 0x76, 0x32, 0x44, 0x5f, 0x62, 0x69, 0x61, 0x73, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0xfc, 0xfe, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x17, 0xac, 0xee, 0x39, 0x5a, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, + 0x48, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x54, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, 0x77, 0x65, 0x69, 0x67, + 0x68, 0x74, 0x73, 0x5f, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x5f, 0x31, 0x2f, + 0x46, 0x61, 0x6b, 0x65, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x57, 0x69, 0x74, + 0x68, 0x4d, 0x69, 0x6e, 0x4d, 0x61, 0x78, 0x56, 0x61, 0x72, 0x73, 0x2f, + 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x30, 0x11, 0x00, 0x00, + 0x0c, 0x00, 0x14, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x3d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x9d, 0xaf, 0xd0, 0x3a, 0x01, 0x00, 0x00, 0x00, + 0xe7, 
0x29, 0x9e, 0x3e, 0x01, 0x00, 0x00, 0x00, 0x5b, 0x91, 0xc3, 0xbd, + 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x08, 0x00, 0x07, 0x00, 0x0c, 0x00, + 0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x20, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x28, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x4d, 0x61, 0x74, 0x4d, + 0x75, 0x6c, 0x5f, 0x62, 0x69, 0x61, 0x73, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x62, 0x1b, 0x1c, 0x3b, + 0x03, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xc0, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x09, + 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x14, 0x00, 0x1c, 0x00, + 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x07, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x18, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, + 0x01, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x14, 0x00, 0x18, 0x00, + 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x14, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x1c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, + 0x04, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, + 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xfa, 0xff, 0xff, 0xff, 0x00, 0x19, 0x06, 0x00, + 0x06, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x09, 0x06, 0x00, + 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04}; +const int g_tiny_conv_model_data_len = 19800; diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h new file mode 100644 index 0000000000000000000000000000000000000000..2953cc852d98fa9b5551ae5036048de9c2ebf674 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This is a standard TensorFlow Lite model file that has been converted into a +// C data array, so it can be easily compiled into a binary for devices that +// don't have a file system. It was created using the command: +// xxd -i tiny_conv.tflite > tiny_conv_model_data.cc + +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ + +extern const unsigned char g_tiny_conv_model_data[]; +extern const int g_tiny_conv_model_data_len; + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.cc b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.cc new file mode 100644 index 0000000000000000000000000000000000000000..3ad29e53c83ddce9fcde7dae578de678d1dc75b8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.cc @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.h" + +/* File automatically created by + * tensorflow/examples/speech_commands/wav_to_features.py \ + * --sample_rate=16000 \ + * --clip_duration_ms=1000 \ + * --window_size_ms=30 \ + * --window_stride_ms=20 \ + * --feature_bin_count=40 \ + * --quantize \ + * --preprocess="average" \ + * --input_wav="speech_commands_test_set_v0.02/yes/f2e59fea_nohash_1.wav" \ + * --output_c_file="yes_features_data.cc" \ + */ + +const int g_yes_f2e59fea_nohash_1_width = 43; +const int g_yes_f2e59fea_nohash_1_height = 49; +const unsigned char g_yes_f2e59fea_nohash_1_data[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 4, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 19, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 
1, 0, 1, 3, 3, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 8, 89, 8, 0, 0, 0, 0, 0, 0, 0, 0, 4, 13, + 1, 6, 23, 20, 6, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19, 177, 42, 1, + 1, 0, 0, 0, 0, 2, 3, 119, 51, 5, 139, 92, 58, 58, 15, 2, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 13, 165, 176, 3, 1, 1, 0, 0, 1, 1, 32, 214, + 26, 19, 113, 103, 28, 22, 27, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 55, 128, + 27, 1, 1, 0, 1, 4, 2, 52, 93, 10, 28, 156, 10, 21, 21, 3, 3, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 14, 99, 32, 65, 7, 1, 2, 2, 6, 13, 121, + 36, 15, 11, 112, 125, 14, 5, 13, 4, 4, 2, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 25, + 32, 5, 1, 0, 0, 0, 1, 0, 7, 5, 1, 1, 3, 3, 0, 3, 3, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 13, 13, 5, 1, 0, 0, 0, 0, 0, 3, + 4, 1, 0, 1, 2, 3, 1, 1, 1, 4, 8, 1, 2, 1, 3, 1, 1, + 0, 1, 1, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 8, 2, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 2, 0, 2, + 1, 0, 2, 0, 2, 2, 3, 1, 1, 0, 1, 1, 4, 5, 1, 0, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 2, 1, 0, 1, 3, 1, + 1, 3, 1, 1, 6, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 2, 6, 2, 4, 2, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 0, 1, 2, 1, 1, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 2, 3, 5, 2, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 
1, 2, 3, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.h b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.h new file mode 100644 index 0000000000000000000000000000000000000000..33ac2308624235fc380782cd61e6a0247b81b093 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.h @@ -0,0 +1,23 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_FEATURES_DATA_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_FEATURES_DATA_H_ + +extern const int g_yes_f2e59fea_nohash_1_width; +extern const int g_yes_f2e59fea_nohash_1_height; +extern const unsigned char g_yes_f2e59fea_nohash_1_data[]; + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_FEATURES_DATA_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/BUILD b/tensorflow/contrib/lite/experimental/micro/kernels/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a012f950e6f58f082d0a7c9ac0b4cd9018bcf40b --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/BUILD @@ -0,0 +1,107 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load( + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl", + "tflite_micro_cc_test", +) + +cc_library( + name = "micro_ops", + srcs = [ + "depthwise_conv.cc", + "fully_connected.cc", + "softmax.cc", + ], + hdrs = [ + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + 
"//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/kernels:op_macros", + "//tensorflow/contrib/lite/kernels:padding", + "//tensorflow/contrib/lite/kernels/internal:quantization_util", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:tensor", + ], +) + +cc_library( + name = "all_ops_resolver", + srcs = [ + "all_ops_resolver.cc", + ], + hdrs = [ + "all_ops_resolver.h", + ], + copts = tflite_copts(), + deps = [ + ":micro_ops", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + ], +) + +cc_library( + name = "test_utils", + srcs = [ + ], + hdrs = [ + "test_utils.h", + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/core/api", + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "depthwise_conv_test", + srcs = [ + "depthwise_conv_test.cc", + ], + deps = [ + ":all_ops_resolver", + ":test_utils", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "fully_connected_test", + srcs = [ + "fully_connected_test.cc", + ], + deps = [ + ":all_ops_resolver", + ":test_utils", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "softmax_test", + srcs = [ + "softmax_test.cc", + ], + deps = [ + ":all_ops_resolver", + ":test_utils", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + 
"//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + ], +) diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc new file mode 100644 index 0000000000000000000000000000000000000000..bd0a37badb8ab1e739fdee9c8be9c3f800e80e2e --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" + +namespace tflite { +namespace ops { +namespace micro { + +TfLiteRegistration* Register_DEPTHWISE_CONV_2D(); +TfLiteRegistration* Micro_Register_DEPTHWISE_CONV_2D() { + return Register_DEPTHWISE_CONV_2D(); +} + +TfLiteRegistration* Register_FULLY_CONNECTED(); +TfLiteRegistration* Micro_Register_FULLY_CONNECTED() { + return Register_FULLY_CONNECTED(); +} + +TfLiteRegistration* Register_SOFTMAX(); +TfLiteRegistration* Micro_Register_SOFTMAX() { return Register_SOFTMAX(); } + +AllOpsResolver::AllOpsResolver() { + AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, + Micro_Register_DEPTHWISE_CONV_2D()); + AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Micro_Register_FULLY_CONNECTED(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_SOFTMAX, Micro_Register_SOFTMAX()); +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h new file mode 100644 index 0000000000000000000000000000000000000000..f836064a3f63443ff577e7ac7a8b791cbb2c24c5 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_ + +#include "tensorflow/contrib/lite/experimental/micro/compatibility.h" +#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h" + +namespace tflite { +namespace ops { +namespace micro { + +class AllOpsResolver : public MicroMutableOpResolver { + public: + AllOpsResolver(); + + private: + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +} // namespace micro +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f17263181982afdaa1941194b88d58f0ef0ca74 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc @@ -0,0 +1,208 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace depthwise_conv { +namespace { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, int width, + int height, int filter_width, int filter_height, + int out_width, int out_height, + const TfLiteType data_type, OpData* data) { + data->padding.height = ComputePadding(params->stride_height, 1, height, + filter_height, out_height); + data->padding.width = + ComputePadding(params->stride_width, 1, width, filter_width, out_width); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. 
+ if (data_type != kTfLiteFloat32) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + return kTfLiteOk; +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + + tflite::DepthwiseParams op_params; + // Padding type is ignored, but still set. 
+ op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = 1; + op_params.dilation_height_factor = 1; + op_params.depth_multiplier = params->depth_multiplier; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + tflite::reference_ops::DepthwiseConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), + GetTensorData(output)); +} + +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + const int32_t input_offset = -input->params.zero_point; + const int32_t filter_offset = -filter->params.zero_point; + const int32_t output_offset = output->params.zero_point; + + tflite::DepthwiseParams op_params; + // Padding type is ignored, but still set. 
+ op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = 1; + op_params.dilation_height_factor = 1; + op_params.depth_multiplier = params->depth_multiplier; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. + op_params.output_shift = -data->output_shift; + + tflite::reference_ops::DepthwiseConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + (NumInputs(node) == 3) ? 
GetInput(context, node, kBiasTensor) : nullptr; + + const TfLiteType data_type = input->type; + int width = SizeOfDimension(input, 2); + int height = SizeOfDimension(input, 1); + int filter_width = SizeOfDimension(filter, 2); + int filter_height = SizeOfDimension(filter, 1); + int out_width = ComputeOutSize(params->padding, width, filter_width, + params->stride_width); + int out_height = ComputeOutSize(params->padding, height, filter_height, + params->stride_height); + OpData local_data_object; + OpData* data = &local_data_object; + TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height, + filter_width, filter_height, out_width, + out_height, data_type, data)); + + // TODO(aselle): Consider whether float conv and quantized conv should be + // separate ops to avoid dispatch overhead here. + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + EvalFloat(context, node, params, data, input, filter, bias, output); + break; + case kTfLiteUInt8: + EvalQuantized(context, node, params, data, input, filter, bias, output); + break; + default: + context->ReportError(context, "Type %d not currently supported.", + input->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace depthwise_conv + +TfLiteRegistration* Register_DEPTHWISE_CONV_2D() { + static TfLiteRegistration r = {depthwise_conv::Init, depthwise_conv::Free, + depthwise_conv::Prepare, depthwise_conv::Eval}; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..169899c471dd44399b4d8a479cecbbbd78ba1215 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc @@ -0,0 +1,406 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h" +#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +void TestDepthwiseConvFloat(std::initializer_list input_dims_data, + std::initializer_list input_data, + std::initializer_list filter_dims_data, + std::initializer_list filter_data, + std::initializer_list bias_dims_data, + std::initializer_list bias_data, + std::initializer_list expected_output_data, + std::initializer_list output_dims_data, + TfLiteFusedActivation activation, + float* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data); + TfLiteIntArray* filter_dims = IntArrayFromInitializer(filter_dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 3; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + 
TfLiteTensor tensors[tensors_size] = { + CreateFloatTensor(input_data, input_dims, "input_tensor"), + CreateFloatTensor(filter_data, filter_dims, "filter_tensor"), + CreateFloatTensor(bias_data, bias_dims, "bias_tensor"), + CreateFloatTensor(output_data, output_dims, "output_tensor"), + }; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + int input_depth = input_dims->data[3]; + int output_depth = filter_dims->data[3]; + int depth_mul = output_depth / input_depth; + TfLiteDepthwiseConvParams builtin_data = { + kTfLitePaddingValid, 1, 1, depth_mul, activation, + }; + const char* init_data = reinterpret_cast(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 3}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, 
user_data); + } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i], + 1e-5f); + } +} + +void TestDepthwiseConvQuantized( + std::initializer_list input_dims_data, + std::initializer_list input_data, float input_min, float input_max, + std::initializer_list filter_dims_data, + std::initializer_list filter_data, float filter_min, + float filter_max, std::initializer_list bias_dims_data, + std::initializer_list bias_data, float bias_min, float bias_max, + std::initializer_list expected_output_data, + std::initializer_list output_dims_data, float output_min, + float output_max, TfLiteFusedActivation activation, uint8_t* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data); + TfLiteIntArray* filter_dims = IntArrayFromInitializer(filter_dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 3; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min, + input_max), + CreateQuantizedTensor(filter_data, filter_dims, "filter_tensor", + filter_min, filter_max), + CreateQuantized32Tensor(bias_data, bias_dims, "bias_tensor", bias_min, + bias_max), + CreateQuantizedTensor(output_data, output_dims, "output_tensor", + output_min, output_max), + }; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + int input_depth = input_dims->data[3]; + int output_depth = filter_dims->data[3]; + int depth_mul = 
output_depth / input_depth; + TfLiteDepthwiseConvParams builtin_data = { + kTfLitePaddingValid, 1, 1, depth_mul, activation, + }; + const char* init_data = reinterpret_cast(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 3}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, user_data); + } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]); + } +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(SimpleTest) { + const int output_dims_count = 8; + float output_data[output_dims_count]; + tflite::testing::TestDepthwiseConvFloat( // + {4, 1, 3, 2, 2}, // Input shape. + { + 1, 2, 7, 8, // Input values. + 3, 4, 9, 10, // + 5, 6, 11, 12, // + }, + {4, 1, 2, 2, 4}, // Filters shape. + { + 1, 2, 3, 4, // Filters values. + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }, + {1, 4}, // Bias shape. 
+ { + 1, 2, 3, 4, // Bias values. + }, + { + 71, -34, 99, -20, // Expected results. + 91, -26, 127, -4, // + }, + {4, 1, 2, 1, 4}, // Output shape. + kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestQuantized) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = -63.5f; + const float input_max = 64.0f; + const float filter_min = -63.5f; + const float filter_max = 64.0f; + const float bias_min = 0.0f; + const float bias_max = 64.0f * (1 << 24); + const float output_min = -127.0f; + const float output_max = 128.0f; + const int output_dims_count = 8; + uint8_t output_data[output_dims_count]; + + tflite::testing::TestDepthwiseConvQuantized( // + {4, 1, 3, 2, 2}, // Input shape. + { + // Input values. + F2Q(1, input_min, input_max), + F2Q(2, input_min, input_max), + F2Q(7, input_min, input_max), + F2Q(8, input_min, input_max), + F2Q(3, input_min, input_max), + F2Q(4, input_min, input_max), + F2Q(9, input_min, input_max), + F2Q(10, input_min, input_max), + F2Q(5, input_min, input_max), + F2Q(6, input_min, input_max), + F2Q(11, input_min, input_max), + F2Q(12, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {4, 1, 2, 2, 4}, // Filter shape. + { + // Filter values. + F2Q(1, filter_min, filter_max), + F2Q(2, filter_min, filter_max), + F2Q(3, filter_min, filter_max), + F2Q(4, filter_min, filter_max), + F2Q(-9, filter_min, filter_max), + F2Q(10, filter_min, filter_max), + F2Q(-11, filter_min, filter_max), + F2Q(12, filter_min, filter_max), + F2Q(5, filter_min, filter_max), + F2Q(6, filter_min, filter_max), + F2Q(7, filter_min, filter_max), + F2Q(8, filter_min, filter_max), + F2Q(13, filter_min, filter_max), + F2Q(-14, filter_min, filter_max), + F2Q(15, filter_min, filter_max), + F2Q(-16, filter_min, filter_max), + }, + filter_min, filter_max, // Filter quantization range. + {1, 4}, // Bias shape. + { + // Bias values. 
+ F2Q32(1, bias_min, bias_max), + F2Q32(2, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + F2Q32(4, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + F2Q(71, output_min, output_max), + F2Q(-34, output_min, output_max), + F2Q(99, output_min, output_max), + F2Q(-20, output_min, output_max), + F2Q(91, output_min, output_max), + F2Q(-26, output_min, output_max), + F2Q(127, output_min, output_max), + F2Q(-4, output_min, output_max), + }, + {4, 1, 2, 1, 4}, // Output shape. + output_min, output_max, // Output quantization range. + kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestRelu) { + const int output_dims_count = 8; + float output_data[output_dims_count]; + tflite::testing::TestDepthwiseConvFloat( // + {4, 1, 3, 2, 2}, // Input shape. + { + 1, 2, 7, 8, // Input values. + 3, 4, 9, 10, // + 5, 6, 11, 12, // + }, + {4, 1, 2, 2, 4}, // Filters shape. + { + 1, 2, 3, 4, // Filters values. + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }, + {1, 4}, // Bias shape. + { + 1, 2, 3, 4, // Bias values. + }, + { + 71, 0, 99, 0, // Expected results. + 91, 0, 127, 0, // + }, + {4, 1, 2, 1, 4}, // Output shape. + kTfLiteActRelu, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestReluQuantized) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = -63.5f; + const float input_max = 64.0f; + const float filter_min = -63.5f; + const float filter_max = 64.0f; + const float bias_min = 0.0f; + const float bias_max = 64.0f * (1 << 24); + const float output_min = -127.0f; + const float output_max = 128.0f; + const int output_dims_count = 8; + uint8_t output_data[output_dims_count]; + + tflite::testing::TestDepthwiseConvQuantized( // + {4, 1, 3, 2, 2}, // Input shape. + { + // Input values. 
+ F2Q(1, input_min, input_max), + F2Q(2, input_min, input_max), + F2Q(7, input_min, input_max), + F2Q(8, input_min, input_max), + F2Q(3, input_min, input_max), + F2Q(4, input_min, input_max), + F2Q(9, input_min, input_max), + F2Q(10, input_min, input_max), + F2Q(5, input_min, input_max), + F2Q(6, input_min, input_max), + F2Q(11, input_min, input_max), + F2Q(12, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {4, 1, 2, 2, 4}, // Filter shape. + { + // Filter values. + F2Q(1, filter_min, filter_max), + F2Q(2, filter_min, filter_max), + F2Q(3, filter_min, filter_max), + F2Q(4, filter_min, filter_max), + F2Q(-9, filter_min, filter_max), + F2Q(10, filter_min, filter_max), + F2Q(-11, filter_min, filter_max), + F2Q(12, filter_min, filter_max), + F2Q(5, filter_min, filter_max), + F2Q(6, filter_min, filter_max), + F2Q(7, filter_min, filter_max), + F2Q(8, filter_min, filter_max), + F2Q(13, filter_min, filter_max), + F2Q(-14, filter_min, filter_max), + F2Q(15, filter_min, filter_max), + F2Q(-16, filter_min, filter_max), + }, + filter_min, filter_max, // Filter quantization range. + {1, 4}, // Bias shape. + { + // Bias values. + F2Q32(1, bias_min, bias_max), + F2Q32(2, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + F2Q32(4, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + F2Q(71, output_min, output_max), + F2Q(0, output_min, output_max), + F2Q(99, output_min, output_max), + F2Q(0, output_min, output_max), + F2Q(91, output_min, output_max), + F2Q(0, output_min, output_max), + F2Q(127, output_min, output_max), + F2Q(0, output_min, output_max), + }, + {4, 1, 2, 1, 4}, // Output shape. + output_min, output_max, // Output quantization range. 
+ kTfLiteActRelu, output_data); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc new file mode 100644 index 0000000000000000000000000000000000000000..1e9e54cafb8c91af1b42d6d23396495ecad6e602 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace fully_connected { +namespace { + +struct OpData { + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. 
+ int32_t output_activation_min; + int32_t output_activation_max; + // The index of the temporary tensor where the quantized inputs are cached. + int input_quantized_index; +}; + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +TfLiteStatus CalculateOpData(TfLiteContext* context, + TfLiteFullyConnectedParams* params, + TfLiteType data_type, const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output, + OpData* data) { + TfLiteStatus status = kTfLiteOk; + if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; + TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( + context, params->activation, output, &data->output_activation_min, + &data->output_activation_max)); + } + return status; +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + TfLiteTensor* output) { + const int32_t input_offset = -input->params.zero_point; + const int32_t filter_offset = -filter->params.zero_point; + const int32_t output_offset = output->params.zero_point; + + tflite::FullyConnectedParams op_params; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + // 
Legacy ops used mixed left and right shifts. Now all are +ve-means-left. + op_params.output_shift = -data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + +#define TF_LITE_FULLY_CONNECTED(output_data_type) \ + reference_ops::FullyConnected( \ + op_params, GetTensorShape(input), GetTensorData(input), \ + GetTensorShape(filter), GetTensorData(filter), \ + GetTensorShape(bias), GetTensorData(bias), \ + GetTensorShape(output), GetTensorData(output), \ + nullptr) + switch (output->type) { + case kTfLiteUInt8: + TF_LITE_FULLY_CONNECTED(uint8_t); + break; + case kTfLiteInt16: + TF_LITE_FULLY_CONNECTED(int16_t); + break; + default: + context->ReportError( + context, + "Quantized FullyConnected expects output data type uint8 or int16"); + return kTfLiteError; + } + + return kTfLiteOk; +} + +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + tflite::FullyConnectedParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + tflite::reference_ops::FullyConnected( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), + GetTensorData(output)); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* bias = 
GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TfLiteType data_type = input->type; + OpData local_data_object; + OpData* data = &local_data_object; + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, + filter, bias, output, data)); + + switch (filter->type) { // Already know in/out types are same. + case kTfLiteFloat32: + return EvalFloat(context, node, params, data, input, filter, bias, + output); + case kTfLiteUInt8: + return EvalQuantized(context, node, params, data, input, filter, bias, + output); + + default: + context->ReportError(context, "Type %d not currently supported.", + filter->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace fully_connected + +TfLiteRegistration* Register_FULLY_CONNECTED() { + static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free, + fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b42bf4c3bca75572dbf8e1907e7fb94be24d41bd --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc @@ -0,0 +1,643 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h" +#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +void TestFullyConnectedFloat(std::initializer_list input_dims_data, + std::initializer_list input_data, + std::initializer_list weights_dims_data, + std::initializer_list weights_data, + std::initializer_list bias_dims_data, + std::initializer_list bias_data, + std::initializer_list expected_output_data, + std::initializer_list output_dims_data, + TfLiteFusedActivation activation, + float* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data); + TfLiteIntArray* weights_dims = IntArrayFromInitializer(weights_dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 3; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateFloatTensor(input_data, input_dims, "input_tensor"), + CreateFloatTensor(weights_data, weights_dims, "weights_tensor"), + CreateFloatTensor(bias_data, bias_dims, "bias_tensor"), + CreateFloatTensor(output_data, output_dims, "output_tensor"), + }; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const 
TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + TfLiteFullyConnectedParams builtin_data = { + activation, + kTfLiteFullyConnectedWeightsFormatDefault, + }; + const char* init_data = reinterpret_cast(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 3}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, user_data); + } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i], + 1e-5f); + } +} + +void TestFullyConnectedQuantized( + std::initializer_list input_dims_data, + std::initializer_list input_data, float input_min, float input_max, + std::initializer_list weights_dims_data, + std::initializer_list weights_data, float weights_min, + float weights_max, std::initializer_list bias_dims_data, + std::initializer_list bias_data, float bias_min, float bias_max, + std::initializer_list 
expected_output_data, + std::initializer_list output_dims_data, float output_min, + float output_max, TfLiteFusedActivation activation, uint8_t* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data); + TfLiteIntArray* weights_dims = IntArrayFromInitializer(weights_dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 3; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min, + input_max), + CreateQuantizedTensor(weights_data, weights_dims, "weights_tensor", + weights_min, weights_max), + CreateQuantized32Tensor(bias_data, bias_dims, "bias_tensor", bias_min, + bias_max), + CreateQuantizedTensor(output_data, output_dims, "output_tensor", + output_min, output_max), + }; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + TfLiteFullyConnectedParams builtin_data = { + activation, + kTfLiteFullyConnectedWeightsFormatDefault, + }; + const char* init_data = reinterpret_cast(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 3}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = 
IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, user_data); + } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]); + } +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(SimpleTest) { + const int output_dims_count = 6; + float output_data[output_dims_count]; + tflite::testing::TestFullyConnectedFloat( // + {2, 2, 10}, // Input shape. + { + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }, + {2, 3, 10}, // Weights shape. + { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }, + {1, 3}, // Bias shape. + { + 1, 2, 3, // Bias values. + }, + { + 24, 25, 26, 58, 59, 60, // Expected results. + }, + {2, 2, 3}, // Output shape. + kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTest2) { + const int output_dims_count = 6; + float output_data[output_dims_count]; + tflite::testing::TestFullyConnectedFloat( // + {2, 2, 2}, // Input shape. + { + 1, 2, // b = 0 + 2, 1, // b = 1 + }, + {2, 1, 2}, // Weights shape. + { + 2, 4, // u = 0 + }, + {1, 1}, // Bias shape. + { + 1, // Bias values. + }, + { + 11, 9, // Expected results. + }, + {2, 2, 1}, // Output shape. 
+ kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestRelu) { + const int output_dims_count = 6; + float output_data[output_dims_count]; + tflite::testing::TestFullyConnectedFloat( // + {2, 2, 10}, // Input shape. + { + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }, + {2, 3, 10}, // Weights shape. + { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }, + {1, 3}, // Bias shape. + { + 1, -2, 3, // Bias values. + }, + { + 24, 0, 26, 58, 0, 60, // Expected results. + }, + {2, 2, 3}, // Output shape. + kTfLiteActRelu, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestQuantized) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = -63.5f; + const float input_max = 64.0f; + const float weights_min = -63.5f; + const float weights_max = 64.0f; + const float bias_min = 0.0f; + const float bias_max = 64.0f * (1 << 24); + const float output_min = -127.0f; + const float output_max = 128.0f; + const int output_dims_count = 6; + uint8_t output_data[output_dims_count]; + tflite::testing::TestFullyConnectedQuantized( // + {2, 2, 10}, // Input shape. + { + // Input values. + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(8, input_min, input_max), + F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max), + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max), + F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {2, 3, 10}, // Weights shape. + { + // Weight values. 
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + }, + weights_min, weights_max, // Weights quantization range. + {1, 3}, // Bias shape. + { + F2Q32(1, bias_min, bias_max), + F2Q32(2, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + F2Q(24, output_min, output_max), + F2Q(25, output_min, output_max), + F2Q(26, output_min, output_max), + F2Q(58, output_min, output_max), + F2Q(59, output_min, output_max), + F2Q(60, output_min, output_max), + }, + {2, 2, 3}, // Output shape. + output_min, output_max, // Output quantization range. 
+ kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestQuantizedRelu) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = -63.5f; + const float input_max = 64.0f; + const float weights_min = -63.5f; + const float weights_max = 64.0f; + const float bias_min = 0.0f; + const float bias_max = 64.0f * (1 << 24); + const float output_min = -127.0f; + const float output_max = 128.0f; + const int output_dims_count = 6; + uint8_t output_data[output_dims_count]; + tflite::testing::TestFullyConnectedQuantized( // + {2, 2, 10}, // Input shape. + { + // Input values. + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(8, input_min, input_max), + F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max), + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max), + F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {2, 3, 10}, // Weights shape. + { + // Weight values. 
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(-1, weights_min, weights_max), F2Q(-2, weights_min, weights_max), + F2Q(-3, weights_min, weights_max), F2Q(-4, weights_min, weights_max), + F2Q(-5, weights_min, weights_max), F2Q(-6, weights_min, weights_max), + F2Q(-7, weights_min, weights_max), F2Q(-8, weights_min, weights_max), + F2Q(-9, weights_min, weights_max), F2Q(-10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + }, + weights_min, weights_max, // Weights quantization range. + {1, 3}, // Bias shape. + { + F2Q32(1, bias_min, bias_max), + F2Q32(0, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + F2Q(24, output_min, output_max), + F2Q(0, output_min, output_max), + F2Q(26, output_min, output_max), + F2Q(58, output_min, output_max), + F2Q(0, output_min, output_max), + F2Q(60, output_min, output_max), + }, + {2, 2, 3}, // Output shape. + output_min, output_max, // Output quantization range. 
+ kTfLiteActRelu, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestQuantizedOutputMultiplierGreaterThan1) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = -127.0f; + const float input_max = 128.0f; + const float weights_min = -127.0f; + const float weights_max = 128.0f; + const float bias_min = 0.0f; + const float bias_max = 256.0f * (1 << 24); + const float output_min = -63.5f; + const float output_max = 64.0f; + const int output_dims_count = 6; + uint8_t output_data[output_dims_count]; + tflite::testing::TestFullyConnectedQuantized( // + {2, 2, 10}, // Input shape. + { + // Input values. + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(8, input_min, input_max), + F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max), + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max), + F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {2, 3, 10}, // Weights shape. + { + // Weight values. 
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + }, + weights_min, weights_max, // Weights quantization range. + {1, 3}, // Bias shape. + { + F2Q32(1, bias_min, bias_max), + F2Q32(2, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + F2Q(24, output_min, output_max), + F2Q(25, output_min, output_max), + F2Q(26, output_min, output_max), + F2Q(58, output_min, output_max), + F2Q(59, output_min, output_max), + F2Q(60, output_min, output_max), + }, + {2, 2, 3}, // Output shape. + output_min, output_max, // Output quantization range. + kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTest4DInput) { + const int output_dims_count = 6; + float output_data[output_dims_count]; + tflite::testing::TestFullyConnectedFloat( // + {4, 1, 1, 5, 1}, // Input shape. + { + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }, + {2, 3, 10}, // Weights shape. 
+ { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }, + {1, 3}, // Bias shape. + { + 1, 2, 3, // Bias values. + }, + { + 24, 25, 26, 58, 59, 60, // Expected results. + }, + {2, 2, 3}, // Output shape. + kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTest4DInputQuantized) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = -63.5f; + const float input_max = 64.0f; + const float weights_min = -63.5f; + const float weights_max = 64.0f; + const float bias_min = 0.0f; + const float bias_max = 64.0f * (1 << 24); + const float output_min = -127.0f; + const float output_max = 128.0f; + const int output_dims_count = 6; + uint8_t output_data[output_dims_count]; + tflite::testing::TestFullyConnectedQuantized( // + {4, 1, 1, 5, 1}, // Input shape. + { + // Input values. + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(8, input_min, input_max), + F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max), + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max), + F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {2, 3, 10}, // Weights shape. + { + // Weight values. 
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + }, + weights_min, weights_max, // Weights quantization range. + {1, 3}, // Bias shape. + { + F2Q32(1, bias_min, bias_max), + F2Q32(2, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + F2Q(24, output_min, output_max), + F2Q(25, output_min, output_max), + F2Q(26, output_min, output_max), + F2Q(58, output_min, output_max), + F2Q(59, output_min, output_max), + F2Q(60, output_min, output_max), + }, + {2, 2, 3}, // Output shape. + output_min, output_max, // Output quantization range. 
+ kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTest4DInputQuantizedOutputMultiplierGreaterThan1) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = -127.0f; + const float input_max = 128.0f; + const float weights_min = -127.0f; + const float weights_max = 128.0f; + const float bias_min = 0.0f; + const float bias_max = 256.0f * (1 << 24); + const float output_min = -63.5f; + const float output_max = 64.0f; + const int output_dims_count = 6; + uint8_t output_data[output_dims_count]; + tflite::testing::TestFullyConnectedQuantized( // + {4, 1, 1, 5, 1}, // Input shape. + { + // Input values. + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(8, input_min, input_max), + F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max), + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max), + F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {2, 3, 10}, // Weights shape. + { + // Weight values. 
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + }, + weights_min, weights_max, // Weights quantization range. + {1, 3}, // Bias shape. + { + F2Q32(1, bias_min, bias_max), + F2Q32(2, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + F2Q(24, output_min, output_max), + F2Q(25, output_min, output_max), + F2Q(26, output_min, output_max), + F2Q(58, output_min, output_max), + F2Q(59, output_min, output_max), + F2Q(60, output_min, output_max), + }, + {2, 2, 3}, // Output shape. + output_min, output_max, // Output quantization range. 
+ kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc b/tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4019a067c563cac25d9918e4bdf75913bdfa3d6 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc @@ -0,0 +1,213 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/kernels/internal/reference/softmax.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace activations { +namespace { + +struct OpData { // Fixed-point softmax parameters filled in by CalculateSoftmaxOpData. + int32_t input_multiplier = 0; + int input_left_shift = 0; + int32_t input_range_radius = 0; // Keeps its default; nothing in this file assigns it. + int diff_min = 0; +}; + +TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, + const TfLiteTensor* input, + TfLiteTensor* output, + const TfLiteSoftmaxParams* params, + OpData* data) { + if (input->type == kTfLiteUInt8) { // Other input types leave *data untouched. + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + TF_LITE_ENSURE(context, output->params.scale == 1. / 256); + + static const int kScaledDiffIntegerBits = 5; + + tflite::PreprocessSoftmaxScaling( + params->beta, input->params.scale, kScaledDiffIntegerBits, + &data->input_multiplier, &data->input_left_shift); + data->diff_min = -1.0 * tflite::CalculateInputRadius( + kScaledDiffIntegerBits, data->input_left_shift); + } + return kTfLiteOk; +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; // No per-node state is allocated. +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; // No preparation work is done here. +} + +// Takes a 1D tensor and performs softmax along it. 
+void Softmax1DFloat(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + const int input_size = input->dims->data[0]; + tflite::reference_ops::Softmax(input->data.f, input_size, 1, params->beta, + output->data.f); +} + +// Takes a 2D tensor and performs softmax along the last dimension. +void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + tflite::reference_ops::Softmax(input->data.f, input_size, batch_size, + params->beta, output->data.f); +} + +void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + // TODO(ahentz): this is arguably a dirty trick. Since the implementation + // always traverses the last dimension of a 4D tensor, we will pretend our 1D + // tensor is 4D in a special way. We will convert a (Y) shape into a (1, + // 1, 1, Y) shape. + const int input_size = input->dims->data[0]; + const int32_t shape_data[4] = {1, 1, 1, input_size}; + RuntimeShape shape(4, shape_data); + SoftmaxParams op_params; + op_params.input_multiplier = data->input_multiplier; + op_params.input_left_shift = data->input_left_shift; + op_params.diff_min = data->diff_min; + tflite::reference_ops::Softmax(op_params, shape, + GetTensorData<uint8_t>(input), shape, + GetTensorData<uint8_t>(output)); +} + +void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + // TODO(ahentz): this is arguably a dirty trick. Since the implementation + // always traverses the last dimension of a 4D tensor, we will pretend our 2D + // tensor is 4D in a special way. We will convert a (X, Y) shape into a (X, + // 1, 1, Y) shape. 
+ const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + const int32_t shape_data[4] = {batch_size, 1, 1, input_size}; + RuntimeShape shape(4, shape_data); + SoftmaxParams op_params; + op_params.input_multiplier = data->input_multiplier; + op_params.input_left_shift = data->input_left_shift; + op_params.diff_min = data->diff_min; + tflite::reference_ops::Softmax(op_params, shape, + GetTensorData<uint8_t>(input), shape, + GetTensorData<uint8_t>(output)); +} + +// Takes a 4D tensor and performs softmax along the fourth dimension. +void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + SoftmaxParams op_params; + op_params.beta = params->beta; + tflite::reference_ops::Softmax( + op_params, GetTensorShape(input), GetTensorData<float>(input), + GetTensorShape(output), GetTensorData<float>(output)); +} + +void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + SoftmaxParams op_params; + op_params.input_multiplier = data->input_multiplier; + op_params.input_left_shift = data->input_left_shift; + op_params.diff_min = data->diff_min; + tflite::reference_ops::Softmax( + op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), + GetTensorShape(output), GetTensorData<uint8_t>(output)); +} + +TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + + OpData local_data_object; // Rebuilt on every call: Init and SoftmaxPrepare keep no state. + OpData* data = &local_data_object; + TF_LITE_ENSURE_STATUS( + CalculateSoftmaxOpData(context, input, output, params, data)); + + // TODO(ahentz): consider an implementation that works for many (all?) + // dimensions. 
+ switch (input->type) { + case kTfLiteFloat32: { + if (NumDimensions(input) == 1) { + Softmax1DFloat(input, output, params); + return kTfLiteOk; + } + if (NumDimensions(input) == 2) { + Softmax2DFloat(input, output, params); + return kTfLiteOk; + } + if (NumDimensions(input) == 4) { + Softmax4DFloat(input, output, params); + return kTfLiteOk; + } + context->ReportError( + context, "Only 1D, 2D and 4D tensors supported currently, got %dD.", + NumDimensions(input)); + return kTfLiteError; + } + case kTfLiteUInt8: { + if (NumDimensions(input) == 1) { + Softmax1DQuantized(input, output, params, data); + return kTfLiteOk; + } + if (NumDimensions(input) == 2) { + Softmax2DQuantized(input, output, params, data); + return kTfLiteOk; + } + if (NumDimensions(input) == 4) { + Softmax4DQuantized(input, output, params, data); + return kTfLiteOk; + } + context->ReportError( + context, "Only 1D, 2D and 4D tensors supported currently, got %dD.", + NumDimensions(input)); + return kTfLiteError; + } + default: + context->ReportError( + context, "Only float32 and uint8_t supported currently, got %d.", + input->type); + return kTfLiteError; + } +} +} // namespace activations + +TfLiteRegistration* Register_SOFTMAX() { + static TfLiteRegistration r = {activations::Init, activations::Free, + activations::SoftmaxPrepare, + activations::SoftmaxEval}; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc b/tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..694456d8ace5182578f9b59c2de8bbad0447b4ee --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc @@ -0,0 +1,220 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h" +#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +void TestSoftmaxFloat(std::initializer_list<int> input_dims_data, + std::initializer_list<float> input_data, + std::initializer_list<float> expected_output_data, + std::initializer_list<int> output_dims_data, + float* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 1; // One input tensor, matching the initialiser below. 
+ constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateFloatTensor(input_data, input_dims, "input_tensor"), + CreateFloatTensor(output_data, output_dims, "output_tensor"), + }; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_SOFTMAX, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + TfLiteSoftmaxParams 
builtin_data = {1.0f}; + const char* init_data = reinterpret_cast<const char*>(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + int inputs_array_data[] = {1, 0}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 1}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast<void*>(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, user_data); + } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i], + 1e-5f); + } +} + +void TestSoftmaxQuantized(std::initializer_list<int> input_dims_data, + std::initializer_list<uint8_t> input_data, + float input_min, float input_max, + std::initializer_list<uint8_t> expected_output_data, + std::initializer_list<int> output_dims_data, + float output_min, float output_max, + uint8_t* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 1; + constexpr int outputs_size = 1; + constexpr int tensors_size = 
inputs_size + outputs_size; + TfLiteTensor 
tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min, + input_max), + CreateQuantizedTensor(output_data, output_dims, "output_tensor", + output_min, output_max), + }; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_SOFTMAX, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + TfLiteSoftmaxParams builtin_data = {1.0f}; + const char* init_data = reinterpret_cast<const char*>(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + + int inputs_array_data[] = {1, 0}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 1}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast<void*>(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, user_data); + } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]); + } +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + 
+TF_LITE_MICRO_TEST(SimpleTest) { + const 
int output_dims_count = 10; + float output_data[output_dims_count]; + tflite::testing::TestSoftmaxFloat( // + {2, 2, 5}, // Input shape. + { + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1 + }, + { + // Expected results. + 0.011656231, + 0.031684921, + 0.086128544, + 0.234121657, + 0.636408647, + 0.636408647, + 0.234121657, + 0.086128544, + 0.031684921, + 0.011656231, + }, + {2, 2, 5}, // Output shape. + output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestQuantized) { + using tflite::testing::F2Q; + + const float input_min = -63.5f; + const float input_max = 64.0f; + const float output_min = 0.0f; + const float output_max = (255.0f / 256.0f); + const int output_dims_count = 5; + uint8_t output_data[output_dims_count]; + tflite::testing::TestSoftmaxQuantized( // + {2, 1, 5}, // Input shape. + { + F2Q(1.0, input_min, input_max), + F2Q(2.0, input_min, input_max), + F2Q(3.0, input_min, input_max), + F2Q(4.0, input_min, input_max), + F2Q(5.0, input_min, input_max), + }, + input_min, input_max, // Input quantized range. + { + // Expected results. + F2Q(0.011656231, output_min, output_max), + F2Q(0.031684921, output_min, output_max), + F2Q(0.086128544, output_min, output_max), + F2Q(0.234121657, output_min, output_max), + F2Q(0.636408647, output_min, output_max), + }, + {2, 1, 5}, // Output shape. + output_min, output_max, // Output quantized range. + output_data); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h b/tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..789a48ece8bd68544649fb05548355cb796ccabb --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h @@ -0,0 +1,170 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ + +#include +#include +#include + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h" +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { + +// How many elements are in the array with this shape. +inline int ElementCount(const TfLiteIntArray& dims) { + int result = 1; + for (int i = 0; i < dims.size; ++i) { + result *= dims.data[i]; + } + return result; +} + +// Wrapper to forward kernel errors to the interpreter's error reporter. +inline void ReportOpError(struct TfLiteContext* context, const char* format, + ...) { + ErrorReporter* error_reporter = static_cast(context->impl_); + va_list args; + va_start(args, format); + error_reporter->Report(format, args); + va_end(args); +} + +// Derives the quantization scaling factor from a min and max range. +template +inline float ScaleFromMinMax(const float min, const float max) { + return (max - min) / ((std::numeric_limits::max() * 1.0) - + std::numeric_limits::min()); +} + +// Derives the quantization zero point from a min and max range. 
+template +inline int ZeroPointFromMinMax(const float min, const float max) { + return static_cast((-min / ScaleFromMinMax(min, max)) + 0.5f); +} + +// Converts a float value into an unsigned eight-bit quantized value. +inline uint8_t F2Q(const float value, const float min, const float max) { + int32_t result = ZeroPointFromMinMax(min, max) + + (value / ScaleFromMinMax(min, max)) + 0.5f; + if (result < 0) { + result = 0; + } + if (result > 256) { + result = 256; + } + return result; +} + +// Converts a float value into a signed thirty-two-bit quantized value. +inline uint8_t F2Q32(const float value, const float min, const float max) { + return static_cast((value - ZeroPointFromMinMax(min, max)) / + ScaleFromMinMax(min, max)); +} + +inline void PopulateContext(TfLiteTensor* tensors, int tensors_size, + TfLiteContext* context) { + context->tensors_size = tensors_size; + context->tensors = tensors; + context->impl_ = static_cast(micro_test::reporter); + context->GetExecutionPlan = nullptr; + context->ResizeTensor = nullptr; + context->ReportError = ReportOpError; + context->AddTensors = nullptr; + context->GetNodeAndRegistration = nullptr; + context->ReplaceSubgraphsWithDelegateKernels = nullptr; + context->recommended_num_threads = 1; + context->GetExternalContext = nullptr; + context->SetExternalContext = nullptr; +} + +inline TfLiteIntArray* IntArrayFromInts(const int* int_array) { + return const_cast( + reinterpret_cast(int_array)); +} + +inline TfLiteIntArray* IntArrayFromInitializer( + std::initializer_list int_initializer) { + return IntArrayFromInts(int_initializer.begin()); +} + +inline TfLiteTensor CreateFloatTensor(const float* data, TfLiteIntArray* dims, + const char* name) { + const size_t bytes = ElementCount(*dims) * sizeof(float); + return { + kTfLiteFloat32, {const_cast(reinterpret_cast(data))}, + dims, {}, + kTfLiteMemNone, bytes, + nullptr, name}; +} + +inline TfLiteTensor CreateFloatTensor(std::initializer_list data, + TfLiteIntArray* dims, const 
char* name) { + return CreateFloatTensor(data.begin(), dims, name); +} + +inline TfLiteTensor CreateQuantizedTensor(const uint8_t* data, + TfLiteIntArray* dims, + const char* name, float min, + float max) { + const size_t bytes = ElementCount(*dims) * sizeof(uint8_t); + const TfLiteQuantizationParams q_params = { + ScaleFromMinMax(min, max), + ZeroPointFromMinMax(min, max)}; + return { + kTfLiteUInt8, {const_cast(reinterpret_cast(data))}, + dims, q_params, + kTfLiteMemNone, bytes, + nullptr, name}; +} + +inline TfLiteTensor CreateQuantizedTensor(std::initializer_list data, + TfLiteIntArray* dims, + const char* name, float min, + float max) { + return CreateQuantizedTensor(data.begin(), dims, name, min, max); +} + +inline TfLiteTensor CreateQuantized32Tensor(const int32_t* data, + TfLiteIntArray* dims, + const char* name, float min, + float max) { + const size_t bytes = ElementCount(*dims) * sizeof(int32_t); + const TfLiteQuantizationParams q_params = { + ScaleFromMinMax(min, max), + ZeroPointFromMinMax(min, max)}; + return { + kTfLiteUInt8, {const_cast(reinterpret_cast(data))}, + dims, q_params, + kTfLiteMemNone, bytes, + nullptr, name}; +} + +inline TfLiteTensor CreateQuantized32Tensor(std::initializer_list data, + TfLiteIntArray* dims, + const char* name, float min, + float max) { + return CreateQuantized32Tensor(data.begin(), dims, name, min, max); +} + +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc new file mode 100644 index 0000000000000000000000000000000000000000..99dd8836611c287b7f76104c29c12a73d219ccb3 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc @@ -0,0 +1,78 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h" + +#ifdef TF_LITE_MCU_DEBUG_LOG +#include +#else // TF_LITE_MCU_DEBUG_LOG +#include +#include +void DebugLog(const char* s) { fprintf(stderr, "%s", s); } +void DebugLogInt32(int32_t i) { fprintf(stderr, "%d", i); } +void DebugLogUInt32(uint32_t i) { fprintf(stderr, "%d", i); } +void DebugLogHex(uint32_t i) { fprintf(stderr, "0x%8x", i); } +void DebugLogFloat(float i) { fprintf(stderr, "%f", i); } +#endif // TF_LITE_MCU_DEBUG_LOG + +namespace tflite { +namespace { +void DebugLogPrintf(const char* format, va_list args) { + const int output_cache_size = 64; + char output_cache[output_cache_size + 1]; + int output_cache_index = 0; + const char* current = format; + while (*current != 0) { + if (*current == '%') { + const char next = *(current + 1); + if ((next == 'd') || (next == 's')) { + current += 1; + if (output_cache_index > 0) { + output_cache[output_cache_index] = 0; + DebugLog(output_cache); + output_cache_index = 0; + } + if (next == 'd') { + DebugLogInt32(va_arg(args, int)); + } else if (next == 's') { + DebugLog(va_arg(args, char*)); + } + } + } else { + output_cache[output_cache_index] = *current; + output_cache_index += 1; + } + if (output_cache_index >= output_cache_size) { + output_cache[output_cache_index] = 0; + DebugLog(output_cache); + output_cache_index = 0; + } + 
current += 1; + } + if (output_cache_index > 0) { + output_cache[output_cache_index] = 0; + DebugLog(output_cache); + output_cache_index = 0; + } + DebugLog("\n"); +} +} // namespace + +int MicroErrorReporter::Report(const char* format, va_list args) { + DebugLogPrintf(format, args); + return 0; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h new file mode 100644 index 0000000000000000000000000000000000000000..33e54f7990af6cff4f8706d2889c335087581af4 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_
+
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/experimental/micro/compatibility.h"
+
+namespace tflite {
+
+// ErrorReporter implementation for the micro runtime. Report() is defined in
+// micro_error_reporter.cc.
+class MicroErrorReporter : public ErrorReporter {
+ public:
+  ~MicroErrorReporter() {}
+  // Logs a printf-style message; overrides ErrorReporter::Report.
+  int Report(const char* format, va_list args) override;
+
+ private:
+  // Macro supplied by compatibility.h.
+  // NOTE(review): presumably keeps the virtual destructor from pulling in a
+  // heap operator delete on bare-metal targets - confirm in compatibility.h.
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ef3c32050c0e826c005f185553974170da7e486a
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc
@@ -0,0 +1,25 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h" + +int main(int argc, char** argv) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + error_reporter->Report("Number: %d", 42); + error_reporter->Report("Badly-formed format string %"); + error_reporter->Report("Another % badly-formed %% format string"); + error_reporter->Report("~~~%s~~~", "ALL TESTS PASSED"); +} diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc new file mode 100644 index 0000000000000000000000000000000000000000..0f38991bb0ef3d0134b4d9a1eb6e148a140fe6f9 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc @@ -0,0 +1,310 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h" + +#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/contrib/lite/experimental/micro/compatibility.h" + +namespace tflite { +namespace { +const int kStackDataAllocatorSize = 128; +class StackDataAllocator : public BuiltinDataAllocator { + public: + void* Allocate(size_t size) override { + if (size > kStackDataAllocatorSize) { + return nullptr; + } else { + return data_; + } + } + void Deallocate(void* data) override { + // Do nothing. + } + + private: + uint8_t data_[kStackDataAllocatorSize]; + + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +const char* OpNameFromRegistration(const TfLiteRegistration* registration) { + if (registration->builtin_code == BuiltinOperator_CUSTOM) { + return registration->custom_name; + } else { + return EnumNameBuiltinOperator(BuiltinOperator(registration->builtin_code)); + } +} + +void ReportOpError(struct TfLiteContext* context, const char* format, ...) 
{ + MicroInterpreter* interpreter = + static_cast(context->impl_); + va_list args; + va_start(args, format); + interpreter->error_reporter()->Report(format, args); + va_end(args); +} + +} // namespace + +MicroInterpreter::MicroInterpreter(const Model* model, + const OpResolver& op_resolver, + SimpleTensorAllocator* tensor_allocator, + ErrorReporter* error_reporter) + : model_(model), + op_resolver_(op_resolver), + tensor_allocator_(tensor_allocator), + error_reporter_(error_reporter), + initialization_status_(kTfLiteOk) { + const flatbuffers::Vector>* buffers = + model->buffers(); + auto* subgraphs = model->subgraphs(); + if (subgraphs->size() != 1) { + error_reporter->Report("Only 1 subgraph is currently supported.\n"); + initialization_status_ = kTfLiteError; + return; + } + subgraph_ = (*subgraphs)[0]; + tensors_ = subgraph_->tensors(); + operators_ = subgraph_->operators(); + + context_.tensors_size = tensors_->Length(); + context_.tensors = + reinterpret_cast(tensor_allocator_->AllocateMemory( + sizeof(TfLiteTensor) * context_.tensors_size)); + for (int i = 0; i < subgraph_->inputs()->Length(); ++i) { + const int tensor_index = subgraph_->inputs()->Get(i); + const auto* tensor = tensors_->Get(tensor_index); + initialization_status_ = tensor_allocator_->AllocateTensor( + *tensor, 0, operators_->Length(), buffers, error_reporter, + &context_.tensors[tensor_index]); + if (initialization_status_ != kTfLiteOk) { + return; + } + } + + int* first_created = reinterpret_cast( + tensor_allocator_->AllocateMemory(sizeof(int) * tensors_->Length())); + int* last_used = reinterpret_cast( + tensor_allocator_->AllocateMemory(sizeof(int) * tensors_->Length())); + for (int i = 0; i < tensors_->Length(); ++i) { + first_created[i] = -1; + last_used[i] = -1; + } + + for (int i = (operators_->Length() - 1); i >= 0; --i) { + const auto* op = operators_->Get(i); + for (int n = 0; n < op->inputs()->Length(); ++n) { + const int tensor_index = op->inputs()->Get(n); + if 
((last_used[tensor_index] == -1) || (last_used[tensor_index] < i)) { + last_used[tensor_index] = i; + } + } + for (int n = 0; n < op->outputs()->Length(); ++n) { + const int tensor_index = op->outputs()->Get(n); + const int create_before = i; + int destroy_after = last_used[tensor_index]; + if (destroy_after == -1) { + destroy_after = operators_->Length(); + } + const auto* tensor = tensors_->Get(tensor_index); + if (!tensor->is_variable()) { + initialization_status_ = tensor_allocator_->AllocateTensor( + *tensor, create_before, destroy_after, buffers, error_reporter, + &context_.tensors[tensor_index]); + if (initialization_status_ != kTfLiteOk) { + return; + } + first_created[tensor_index] = i; + } + } + } + + for (int i = 0; i < tensors_->Length(); ++i) { + const auto* tensor = tensors_->Get(i); + const bool is_read_only = (first_created[i] == -1) && (last_used[i] != -1); + if (tensor->is_variable() || is_read_only) { + initialization_status_ = tensor_allocator_->AllocateTensor( + *tensor, 0, operators_->Length(), buffers, error_reporter, + &context_.tensors[i]); + if (initialization_status_ != kTfLiteOk) { + return; + } + } + } + context_.impl_ = static_cast(this); + context_.GetExecutionPlan = nullptr; + context_.ResizeTensor = nullptr; + context_.ReportError = ReportOpError; + context_.AddTensors = nullptr; + context_.GetNodeAndRegistration = nullptr; + context_.ReplaceSubgraphsWithDelegateKernels = nullptr; + context_.recommended_num_threads = 1; + context_.GetExternalContext = nullptr; + context_.SetExternalContext = nullptr; +} + +TfLiteStatus MicroInterpreter::Invoke() { + if (initialization_status_ != kTfLiteOk) { + error_reporter_->Report("Invoke() called after initialization failed\n"); + return kTfLiteError; + } + TfLiteStatus status = kTfLiteOk; + auto opcodes = model_->operator_codes(); + for (int i = 0; i < operators_->Length(); ++i) { + const auto* op = operators_->Get(i); + int index = op->opcode_index(); + if (index < 0 || index >= 
opcodes->size()) { + error_reporter_->Report("Missing registration for opcode_index %d\n", + index); + return kTfLiteError; + } + auto opcode = (*opcodes)[index]; + const TfLiteRegistration* registration = nullptr; + status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_, + ®istration); + if (status != kTfLiteOk) { + return status; + } + if (registration == nullptr) { + error_reporter_->Report("Skipping op for opcode_index %d\n", index); + return kTfLiteError; + } + BuiltinOperator op_type = + static_cast(registration->builtin_code); + + if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) { + error_reporter_->Report( + "Found builtin operator %s with custom options.\n", + EnumNameBuiltinOperator(op_type)); + } + StackDataAllocator stack_data_allocator; + const char* custom_data = nullptr; + size_t custom_data_size = 0; + unsigned char* builtin_data = nullptr; + if (op->custom_options()) { + custom_data = reinterpret_cast(op->custom_options()->data()); + custom_data_size = op->custom_options()->size(); + } else { + TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_, + &stack_data_allocator, + (void**)(&builtin_data))); + } + + const char* init_data; + size_t init_data_size; + if (registration->builtin_code == BuiltinOperator_CUSTOM) { + init_data = custom_data; + init_data_size = custom_data_size; + } else { + init_data = reinterpret_cast(builtin_data); + init_data_size = 0; + } + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context_, init_data, init_data_size); + } + + const int kMaxInputs = 16; + int inputs_data[kMaxInputs + 1]; + TfLiteIntArray* inputs_array = + reinterpret_cast(inputs_data); + if (op->inputs()->Length() >= kMaxInputs) { + error_reporter_->Report("Too many inputs (%d)\n", op->inputs()->Length()); + return kTfLiteError; + } + inputs_array->size = op->inputs()->Length(); + for (int n = 0; n < op->inputs()->Length(); ++n) { + inputs_array->data[n] = 
op->inputs()->Get(n); + } + + const int kMaxOutputs = 16; + int outputs_data[kMaxOutputs + 1]; + TfLiteIntArray* outputs_array = + reinterpret_cast(outputs_data); + if (op->outputs()->Length() >= kMaxOutputs) { + error_reporter_->Report("Too many outputs (%d)\n", + op->outputs()->Length()); + return kTfLiteError; + } + outputs_array->size = op->outputs()->Length(); + for (int n = 0; n < op->outputs()->Length(); ++n) { + outputs_array->data[n] = op->outputs()->Get(n); + } + + const int kMaxTemporaries = 16; + int temporaries_data[kMaxTemporaries + 1]; + TfLiteIntArray* temporaries_array = + reinterpret_cast(temporaries_data); + temporaries_array->size = 0; + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast(builtin_data); + node.custom_initial_data = custom_data; + node.custom_initial_data_size = custom_data_size; + node.delegate = nullptr; + if (registration->prepare) { + TfLiteStatus prepare_status = registration->prepare(&context_, &node); + if (prepare_status != kTfLiteOk) { + error_reporter_->Report( + "Node %s (number %d) failed to prepare with status %d", + OpNameFromRegistration(registration), i, prepare_status); + return kTfLiteError; + } + } + + if (registration->invoke) { + TfLiteStatus invoke_status = registration->invoke(&context_, &node); + if (invoke_status != kTfLiteOk) { + error_reporter_->Report( + "Node %s (number %d) failed to invoke with status %d", + OpNameFromRegistration(registration), i, invoke_status); + return kTfLiteError; + } + } + + if (registration->free) { + registration->free(&context_, user_data); + } + } + return status; +} + +TfLiteTensor* MicroInterpreter::input(int index) { + const flatbuffers::Vector* inputs = subgraph_->inputs(); + const size_t length = inputs->Length(); + if ((index < 0) || (index >= length)) { + error_reporter_->Report("Input index %d out of range (length is %d)", index, 
+ length); + return nullptr; + } + return &(context_.tensors[inputs->Get(index)]); +} + +TfLiteTensor* MicroInterpreter::output(int index) { + const flatbuffers::Vector* outputs = subgraph_->outputs(); + const size_t length = outputs->Length(); + if ((index < 0) || (index >= outputs->Length())) { + error_reporter_->Report("Output index %d out of range (length is %d)", + index, length); + return nullptr; + } + return &(context_.tensors[outputs->Get(index)]); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter.h b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.h new file mode 100644 index 0000000000000000000000000000000000000000..a88514cde849595244d36a31900e6d1c2ae1714b --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.h @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_ + +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/op_resolver.h" +#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { + +class MicroInterpreter { + public: + // The lifetime of the model, op resolver, allocator, and error reporter must + // be at least as long as that of the interpreter object, since the + // interpreter may need to access them at any time. This means that you should + // usually create them with the same scope as each other, for example having + // them all allocated on the stack as local variables through a top-level + // function. + // The interpreter doesn't do any deallocation of any of the pointed-to + // objects, ownership remains with the caller. 
+ MicroInterpreter(const Model* model, const OpResolver& op_resolver, + SimpleTensorAllocator* tensor_allocator, + ErrorReporter* error_reporter); + + TfLiteStatus Invoke(); + + size_t tensors_size() const { return context_.tensors_size; } + TfLiteTensor* tensor(int tensor_index); + + TfLiteTensor* input(int index); + size_t inputs_size() const { return subgraph_->inputs()->Length(); } + + TfLiteTensor* output(int index); + size_t outputs_size() const { return subgraph_->outputs()->Length(); } + + TfLiteStatus initialization_status() const { return initialization_status_; } + + ErrorReporter* error_reporter() { return error_reporter_; } + + private: + const Model* model_; + const OpResolver& op_resolver_; + SimpleTensorAllocator* tensor_allocator_; + ErrorReporter* error_reporter_; + + TfLiteStatus initialization_status_; + const flatbuffers::Vector>* tensors_; + const flatbuffers::Vector>* operators_; + TfLiteContext context_; + + const SubGraph* subgraph_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc b/tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..251e5f72037717f74bc3472b69144cff299f0668 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc @@ -0,0 +1,197 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h" + +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" + +namespace tflite { +namespace { +void* MockInit(TfLiteContext* context, const char* buffer, size_t length) { + // Do nothing. + return nullptr; +} + +void MockFree(TfLiteContext* context, void* buffer) { + // Do nothing. +} + +TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; + const int32_t* input_data = input->data.i32; + const TfLiteTensor* weight = &context->tensors[node->inputs->data[1]]; + const uint8_t* weight_data = weight->data.uint8; + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + int32_t* output_data = output->data.i32; + output_data[0] = input_data[0] + weight_data[0]; + return kTfLiteOk; +} + +class MockOpResolver : public OpResolver { + public: + const TfLiteRegistration* FindOp(BuiltinOperator op, + int version) const override { + return nullptr; + } + const TfLiteRegistration* FindOp(const char* op, int version) const override { + if (strcmp(op, "mock_custom") == 0) { + static TfLiteRegistration r = {MockInit, MockFree, MockPrepare, + MockInvoke}; + return &r; + } else { + return nullptr; + } + } +}; + +class StackAllocator : public flatbuffers::Allocator { + public: + StackAllocator() : data_(data_backing_), data_size_(0) {} + + uint8_t* allocate(size_t size) override { + if ((data_size_ + size) > kStackAllocatorSize) { + // TODO(petewarden): Add error reporting beyond returning null! 
+ return nullptr; + } + uint8_t* result = data_; + data_ += size; + data_size_ += size; + return result; + } + + void deallocate(uint8_t* p, size_t) override {} + + static StackAllocator& instance() { + // Avoid using true dynamic memory allocation to be portable to bare metal. + static char inst_memory[sizeof(StackAllocator)]; + static StackAllocator* inst = new (inst_memory) StackAllocator; + return *inst; + } + + static constexpr int kStackAllocatorSize = 4096; + + private: + uint8_t data_backing_[kStackAllocatorSize]; + uint8_t* data_; + int data_size_; +}; + +const Model* BuildMockModel() { + using flatbuffers::Offset; + flatbuffers::FlatBufferBuilder builder(StackAllocator::kStackAllocatorSize, + &StackAllocator::instance()); + constexpr size_t buffer_data_size = 1; + const uint8_t buffer_data[buffer_data_size] = {21}; + constexpr size_t buffers_size = 2; + const Offset buffers[buffers_size] = { + CreateBuffer(builder), + CreateBuffer(builder, + builder.CreateVector(buffer_data, buffer_data_size))}; + constexpr size_t tensor_shape_size = 1; + const int32_t tensor_shape[tensor_shape_size] = {1}; + constexpr size_t tensors_size = 3; + const Offset tensors[tensors_size] = { + CreateTensor(builder, + builder.CreateVector(tensor_shape, tensor_shape_size), + TensorType_INT32, 0, + builder.CreateString("test_input_tensor"), 0, false), + CreateTensor(builder, + builder.CreateVector(tensor_shape, tensor_shape_size), + TensorType_UINT8, 1, + builder.CreateString("test_weight_tensor"), 0, false), + CreateTensor(builder, + builder.CreateVector(tensor_shape, tensor_shape_size), + TensorType_INT32, 0, + builder.CreateString("test_output_tensor"), 0, false), + }; + constexpr size_t inputs_size = 1; + const int32_t inputs[inputs_size] = {0}; + constexpr size_t outputs_size = 1; + const int32_t outputs[outputs_size] = {2}; + constexpr size_t operator_inputs_size = 2; + const int32_t operator_inputs[operator_inputs_size] = {0, 1}; + constexpr size_t operator_outputs_size = 1; 
+ const int32_t operator_outputs[operator_outputs_size] = {2}; + constexpr size_t operators_size = 1; + const Offset operators[operators_size] = {CreateOperator( + builder, 0, builder.CreateVector(operator_inputs, operator_inputs_size), + builder.CreateVector(operator_outputs, operator_outputs_size), + BuiltinOptions_NONE)}; + constexpr size_t subgraphs_size = 1; + const Offset subgraphs[subgraphs_size] = { + CreateSubGraph(builder, builder.CreateVector(tensors, tensors_size), + builder.CreateVector(inputs, inputs_size), + builder.CreateVector(outputs, outputs_size), + builder.CreateVector(operators, operators_size), + builder.CreateString("test_subgraph"))}; + constexpr size_t operator_codes_size = 1; + const Offset operator_codes[operator_codes_size] = { + CreateOperatorCodeDirect(builder, BuiltinOperator_CUSTOM, "mock_custom", + 0)}; + const Offset model_offset = CreateModel( + builder, 0, builder.CreateVector(operator_codes, operator_codes_size), + builder.CreateVector(subgraphs, subgraphs_size), + builder.CreateString("test_model"), + builder.CreateVector(buffers, buffers_size)); + FinishModelBuffer(builder, model_offset); + void* model_pointer = builder.GetBufferPointer(); + const Model* model = flatbuffers::GetRoot(model_pointer); + return model; +} + +} // namespace +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestInterpreter) { + const tflite::Model* model = tflite::BuildMockModel(); + TF_LITE_MICRO_EXPECT_NE(nullptr, model); + tflite::MockOpResolver mock_resolver; + constexpr size_t allocator_buffer_size = 1024; + uint8_t allocator_buffer[allocator_buffer_size]; + tflite::SimpleTensorAllocator simple_tensor_allocator(allocator_buffer, + allocator_buffer_size); + tflite::MicroInterpreter interpreter( + model, mock_resolver, &simple_tensor_allocator, micro_test::reporter); + TF_LITE_MICRO_EXPECT_EQ(1, interpreter.inputs_size()); + TF_LITE_MICRO_EXPECT_EQ(1, interpreter.outputs_size()); + + TfLiteTensor* input = 
interpreter.input(0); + TF_LITE_MICRO_EXPECT_NE(nullptr, input); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, input->type); + TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size); + TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(4, input->bytes); + TF_LITE_MICRO_EXPECT_NE(nullptr, input->data.i32); + input->data.i32[0] = 21; + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke()); + + TfLiteTensor* output = interpreter.output(0); + TF_LITE_MICRO_EXPECT_NE(nullptr, output); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, output->type); + TF_LITE_MICRO_EXPECT_EQ(1, output->dims->size); + TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(4, output->bytes); + TF_LITE_MICRO_EXPECT_NE(nullptr, output->data.i32); + TF_LITE_MICRO_EXPECT_EQ(42, output->data.i32[0]); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc new file mode 100644 index 0000000000000000000000000000000000000000..40c21c6448c39f27c12e95ae36038510cb346362 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h" + +namespace tflite { + +const TfLiteRegistration* MicroMutableOpResolver::FindOp( + tflite::BuiltinOperator op, int version) const { + for (int i = 0; i < registrations_len_; ++i) { + const TfLiteRegistration& registration = registrations_[i]; + if ((registration.builtin_code == op) && + (registration.version == version)) { + return ®istration; + } + } + return nullptr; +} + +const TfLiteRegistration* MicroMutableOpResolver::FindOp(const char* op, + int version) const { + for (int i = 0; i < registrations_len_; ++i) { + const TfLiteRegistration& registration = registrations_[i]; + if ((registration.builtin_code == -1) && + (strcmp(registration.custom_name, op) == 0) && + (registration.version == version)) { + return ®istration; + } + } + return nullptr; +} + +void MicroMutableOpResolver::AddBuiltin(tflite::BuiltinOperator op, + TfLiteRegistration* registration, + int min_version, int max_version) { + for (int version = min_version; version <= max_version; ++version) { + if (registrations_len_ >= TFLITE_REGISTRATIONS_MAX) { + // TODO(petewarden) - Add error reporting hooks so we can report this! + return; + } + TfLiteRegistration* new_registration = ®istrations_[registrations_len_]; + registrations_len_ += 1; + + *new_registration = *registration; + new_registration->builtin_code = op; + new_registration->version = version; + } +} + +void MicroMutableOpResolver::AddCustom(const char* name, + TfLiteRegistration* registration, + int min_version, int max_version) { + for (int version = min_version; version <= max_version; ++version) { + if (registrations_len_ >= TFLITE_REGISTRATIONS_MAX) { + // TODO(petewarden) - Add error reporting hooks so we can report this! 
+ return; + } + TfLiteRegistration* new_registration = ®istrations_[registrations_len_]; + registrations_len_ += 1; + + *new_registration = *registration; + new_registration->builtin_code = -1; + new_registration->custom_name = name; + new_registration->version = version; + } +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h new file mode 100644 index 0000000000000000000000000000000000000000..f3750a248416cc7244e0dea82be167562fd59ee7 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ + +#include "tensorflow/contrib/lite/core/api/op_resolver.h" +#include "tensorflow/contrib/lite/experimental/micro/compatibility.h" + +#ifndef TFLITE_REGISTRATIONS_MAX +#define TFLITE_REGISTRATIONS_MAX (128) +#endif + +namespace tflite { + +class MicroMutableOpResolver : public OpResolver { + public: + const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const override; + const TfLiteRegistration* FindOp(const char* op, int version) const override; + void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); + void AddCustom(const char* name, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); + + private: + TfLiteRegistration registrations_[TFLITE_REGISTRATIONS_MAX]; + int registrations_len_ = 0; + + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5420a33e8778d93d5aad2150438fdba80df372b8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h" + +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" + +namespace tflite { +namespace { +void* MockInit(TfLiteContext* context, const char* buffer, size_t length) { + // Do nothing. + return nullptr; +} + +void MockFree(TfLiteContext* context, void* buffer) { + // Do nothing. +} + +TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} +} // namespace +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestOperations) { + using tflite::BuiltinOperator_CONV_2D; + using tflite::BuiltinOperator_RELU; + using tflite::MicroMutableOpResolver; + using tflite::OpResolver; + + static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, + tflite::MockPrepare, tflite::MockInvoke}; + + MicroMutableOpResolver micro_mutable_op_resolver; + micro_mutable_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r, 0, 2); + micro_mutable_op_resolver.AddCustom("mock_custom", &r, 0, 3); + OpResolver* resolver = µ_mutable_op_resolver; + + const TfLiteRegistration* registration = + resolver->FindOp(BuiltinOperator_CONV_2D, 0); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + TF_LITE_MICRO_EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0)); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr)); + 
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr)); + + registration = resolver->FindOp(BuiltinOperator_CONV_2D, 10); + TF_LITE_MICRO_EXPECT_EQ(nullptr, registration); + + registration = resolver->FindOp(BuiltinOperator_RELU, 0); + TF_LITE_MICRO_EXPECT_EQ(nullptr, registration); + + registration = resolver->FindOp("mock_custom", 0); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + TF_LITE_MICRO_EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0)); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr)); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr)); + + registration = resolver->FindOp("mock_custom", 10); + TF_LITE_MICRO_EXPECT_EQ(nullptr, registration); + + registration = resolver->FindOp("nonexistent_custom", 0); + TF_LITE_MICRO_EXPECT_EQ(nullptr, registration); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc new file mode 100644 index 0000000000000000000000000000000000000000..8c090a20a5fb9e6cb330a40c86236c549c28539e --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc @@ -0,0 +1,149 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" + +#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" + +namespace tflite { +namespace { + +TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size, + ErrorReporter* reporter) { + switch (type) { + case kTfLiteFloat32: + *size = sizeof(float); + break; + case kTfLiteInt16: + *size = sizeof(int16_t); + break; + case kTfLiteInt32: + *size = sizeof(int32_t); + break; + case kTfLiteUInt8: + *size = sizeof(uint8_t); + break; + case kTfLiteInt64: + *size = sizeof(int64_t); + break; + case kTfLiteBool: + *size = sizeof(bool); + break; + case kTfLiteComplex64: + *size = sizeof(float) * 2; + break; + default: + reporter->Report( + "Only float32, int16, int32, int64, uint8, bool, complex64 " + "supported currently."); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus BytesRequired(const tflite::Tensor& flatbuffer_tensor, + size_t dims_size, size_t* bytes, + ErrorReporter* error_reporter) { + TfLiteType tf_lite_type; + TF_LITE_ENSURE_STATUS(ConvertTensorType(flatbuffer_tensor.type(), + &tf_lite_type, error_reporter)); + size_t type_size; + TF_LITE_ENSURE_STATUS( + TfLiteTypeSizeOf(tf_lite_type, &type_size, error_reporter)); + *bytes = dims_size * type_size; + return kTfLiteOk; +} + +} // namespace + +TfLiteStatus SimpleTensorAllocator::AllocateTensor( + const tflite::Tensor& flatbuffer_tensor, int create_before, + int destroy_after, + const flatbuffers::Vector>* buffers, + ErrorReporter* error_reporter, TfLiteTensor* result) { + TF_LITE_ENSURE_STATUS(ConvertTensorType(flatbuffer_tensor.type(), + &result->type, error_reporter)); + result->is_variable = flatbuffer_tensor.is_variable(); + + result->data.raw = nullptr; + result->bytes = 0; + if (auto* buffer = (*buffers)[flatbuffer_tensor.buffer()]) { + if (auto* array = buffer->data()) { + if (size_t array_size = array->size()) { 
+ result->data.raw = + const_cast(reinterpret_cast(array->data())); + TF_LITE_ENSURE_STATUS(BytesRequired(flatbuffer_tensor, array_size, + &result->bytes, error_reporter)); + } + } + } + if (result->data.raw) { + result->allocation_type = kTfLiteMmapRo; + } else { + int data_size = 1; + for (int n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) { + data_size *= flatbuffer_tensor.shape()->Get(n); + } + TF_LITE_ENSURE_STATUS(BytesRequired(flatbuffer_tensor, data_size, + &result->bytes, error_reporter)); + result->data.raw = reinterpret_cast(AllocateMemory(result->bytes)); + if (result->data.raw == nullptr) { + const char* tensor_name = flatbuffer_tensor.name()->c_str(); + if (tensor_name == nullptr) { + tensor_name = ""; + } + error_reporter->Report( + "Couldn't allocate memory for tensor '%s', wanted %d bytes but only " + "%d were available", + tensor_name, result->bytes, (data_size_max_ - data_size_)); + return kTfLiteError; + } + result->allocation_type = kTfLiteArenaRw; + } + result->dims = reinterpret_cast( + AllocateMemory(sizeof(int) * (flatbuffer_tensor.shape()->Length() + 1))); + result->dims->size = flatbuffer_tensor.shape()->Length(); + for (int n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) { + result->dims->data[n] = flatbuffer_tensor.shape()->Get(n); + } + if (flatbuffer_tensor.quantization()) { + result->params.scale = flatbuffer_tensor.quantization()->scale()->Get(0); + result->params.zero_point = + flatbuffer_tensor.quantization()->zero_point()->Get(0); + } + result->allocation = nullptr; + if (flatbuffer_tensor.name()) { + result->name = flatbuffer_tensor.name()->c_str(); + } else { + result->name = ""; + } + result->delegate = nullptr; + result->buffer_handle = 0; + result->data_is_stale = false; + return kTfLiteOk; +} + +uint8_t* SimpleTensorAllocator::AllocateMemory(size_t size) { + if ((data_size_ + size) > data_size_max_) { + // TODO(petewarden): Add error reporting beyond returning null! 
+ return nullptr; + } + uint8_t* result = data_; + data_ += size; + data_size_ += size; + return result; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..4f16a9d0e54cba6fb3b635ceeb39ab10ff59ae73 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_ + +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { + +// TODO(petewarden): This allocator never frees up or reuses any memory, even +// though we have enough information about lifetimes of the tensors to do so. +// This makes it pretty wasteful, so we should use a more intelligent method. 
+class SimpleTensorAllocator { + public: + SimpleTensorAllocator(uint8_t* buffer, int buffer_size) + : data_size_(0), data_size_max_(buffer_size), data_(buffer) {} + + TfLiteStatus AllocateTensor( + const tflite::Tensor& flatbuffer_tensor, int create_before, + int destroy_after, + const flatbuffers::Vector>* buffers, + ErrorReporter* error_reporter, TfLiteTensor* result); + + uint8_t* AllocateMemory(size_t size); + + int GetDataSize() const { return data_size_; } + + private: + int data_size_; + int data_size_max_; + uint8_t* data_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c83542724395328cb6a5e038b64dba4b9f4f655b --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc @@ -0,0 +1,144 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h" + +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" + +namespace tflite { +namespace { +class StackAllocator : public flatbuffers::Allocator { + public: + StackAllocator() : data_(data_backing_), data_size_(0) {} + + uint8_t* allocate(size_t size) override { + if ((data_size_ + size) > kStackAllocatorSize) { + // TODO(petewarden): Add error reporting beyond returning null! + return nullptr; + } + uint8_t* result = data_; + data_ += size; + data_size_ += size; + return result; + } + + void deallocate(uint8_t* p, size_t) override {} + + static StackAllocator& instance() { + // Avoid using true dynamic memory allocation to be portable to bare metal. + static char inst_memory[sizeof(StackAllocator)]; + static StackAllocator* inst = new (inst_memory) StackAllocator; + return *inst; + } + + static constexpr int kStackAllocatorSize = 4096; + + private: + uint8_t data_backing_[kStackAllocatorSize]; + uint8_t* data_; + int data_size_; +}; + +flatbuffers::FlatBufferBuilder* BuilderInstance() { + static char inst_memory[sizeof(flatbuffers::FlatBufferBuilder)]; + static flatbuffers::FlatBufferBuilder* inst = + new (inst_memory) flatbuffers::FlatBufferBuilder( + StackAllocator::kStackAllocatorSize, &StackAllocator::instance()); + return inst; +} + +const Tensor* Create1dTensor(int size) { + using flatbuffers::Offset; + flatbuffers::FlatBufferBuilder* builder = BuilderInstance(); + constexpr size_t tensor_shape_size = 1; + const int32_t tensor_shape[tensor_shape_size] = {size}; + const Offset tensor_offset = CreateTensor( + *builder, builder->CreateVector(tensor_shape, tensor_shape_size), + TensorType_INT32, 0, builder->CreateString("test_tensor"), 0, false); + builder->Finish(tensor_offset); + void* tensor_pointer = builder->GetBufferPointer(); + const Tensor* tensor = 
flatbuffers::GetRoot(tensor_pointer); + return tensor; +} + +const flatbuffers::Vector>* CreateBuffers() { + using flatbuffers::Offset; + flatbuffers::FlatBufferBuilder* builder = BuilderInstance(); + constexpr size_t buffers_size = 1; + const Offset buffers[buffers_size] = { + CreateBuffer(*builder), + }; + const flatbuffers::Offset>> + buffers_offset = builder->CreateVector(buffers, buffers_size); + builder->Finish(buffers_offset); + void* buffers_pointer = builder->GetBufferPointer(); + const flatbuffers::Vector>* result = + flatbuffers::GetRoot>>( + buffers_pointer); + return result; +} + +} // namespace +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestAllocateTensor) { + constexpr size_t arena_size = 1024; + uint8_t arena[arena_size]; + tflite::SimpleTensorAllocator allocator(arena, arena_size); + + const tflite::Tensor* tensor = tflite::Create1dTensor(100); + const flatbuffers::Vector>* buffers = + tflite::CreateBuffers(); + + TfLiteTensor allocated_tensor; + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + allocator.AllocateTensor(*tensor, 0, 1, buffers, micro_test::reporter, + &allocated_tensor)); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, allocated_tensor.type); + TF_LITE_MICRO_EXPECT_EQ(1, allocated_tensor.dims->size); + TF_LITE_MICRO_EXPECT_EQ(100, allocated_tensor.dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(400, allocated_tensor.bytes); + TF_LITE_MICRO_EXPECT_NE(nullptr, allocated_tensor.data.i32); +} + +TF_LITE_MICRO_TEST(TestTooLarge) { + constexpr size_t arena_size = 1024; + uint8_t arena[arena_size]; + tflite::SimpleTensorAllocator allocator(arena, arena_size); + + const tflite::Tensor* tensor = tflite::Create1dTensor(10000); + const flatbuffers::Vector>* buffers = + tflite::CreateBuffers(); + + TfLiteTensor allocated_tensor; + TF_LITE_MICRO_EXPECT_NE( + kTfLiteOk, + allocator.AllocateTensor(*tensor, 0, 1, buffers, micro_test::reporter, + &allocated_tensor)); +} + +TF_LITE_MICRO_TEST(TestJustFits) { + constexpr size_t arena_size = 
1024; + uint8_t arena[arena_size]; + tflite::SimpleTensorAllocator allocator(arena, arena_size); + + uint8_t* result = allocator.AllocateMemory(arena_size); + TF_LITE_MICRO_EXPECT_NE(nullptr, result); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/testing/BUILD b/tensorflow/contrib/lite/experimental/micro/testing/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..0d23be5712ad1bc6d81cc467cce8c9927caece3d --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/testing/BUILD @@ -0,0 +1,17 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["test_linux_binary.sh"]) + +cc_library( + name = "micro_test", + hdrs = [ + "micro_test.h", + ], + deps = [ + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + ], +) diff --git a/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill b/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill new file mode 100644 index 0000000000000000000000000000000000000000..7d6d81af0f482afb7a9f0624b5262a5277112976 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill @@ -0,0 +1,21 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +# This docker configuration file lets you emulate a Blue Pill board +# on an x86 desktop or laptop, which can be useful for debugging and +# automated testing. +FROM antmicro/renode:latest + +LABEL maintainer="Pete Warden " \ No newline at end of file diff --git a/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc b/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc new file mode 100644 index 0000000000000000000000000000000000000000..c46b33e3fb0b188c0c108e69ebc05063c0e00575 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc @@ -0,0 +1,33 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +using sysbus + +mach create +machine LoadPlatformDescription @platforms/cpus/stm32f103.repl + +# These lines are needed to show the results of DebugLog calls in the output. 
+machine LoadPlatformDescriptionFromString "uartSemihosting: UART.SemihostingUart @ cpu"
+showAnalyzer cpu.uartSemihosting Antmicro.Renode.Analyzers.LoggingUartAnalyzer
+
+logFile @/tmp/renode_bluepill_log.txt
+
+macro reset
+"""
+    sysbus LoadELF $bin
+"""
+
+runMacro $reset
+
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/bluepill.robot b/tensorflow/contrib/lite/experimental/micro/testing/bluepill.robot
new file mode 100644
index 0000000000000000000000000000000000000000..f09c3a0cc0df841e3354e517a8416343df8726bd
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/bluepill.robot
@@ -0,0 +1,23 @@
+*** Settings ***
+Suite Setup                   Setup
+Suite Teardown                Teardown
+Test Setup                    Reset Emulation
+Resource                      /opt/renode/tests/renode-keywords.robot
+
+*** Variables ***
+${UART}                       sysbus.cpu.uartSemihosting
+
+*** Test Cases ***
+Should Run Bluepill Test
+    [Documentation]           Runs a Bluepill test and waits for a specific string on the semihosting UART
+    [Tags]                    bluepill  uart  tensorflow  arm
+    ${BIN} =                  Get Environment Variable    BIN
+    ${SCRIPT} =               Get Environment Variable    SCRIPT
+    ${EXPECTED} =             Get Environment Variable    EXPECTED
+    Execute Command           $bin = @${BIN}
+    Execute Script            ${SCRIPT}
+
+    Create Terminal Tester    ${UART}  timeout=3
+    Start Emulation
+
+    Wait For Line On Uart     ${EXPECTED}
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl
new file mode 100644
index 0000000000000000000000000000000000000000..916e3eeac394f9a815d7c1785d253fd54ca7aa0e
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl
@@ -0,0 +1,67 @@
+"""Rules for simple testing without dependencies by parsing output logs."""
+
+def tflite_micro_cc_test(
+        name,
+        expected_in_logs = "~~~ALL TESTS PASSED~~~",
+        srcs = [],
+        includes = [],
+        defines = [],
+        copts = [],
+        nocopts = "",
+        linkopts = [],
+        deps = [],
+        tags = [],
+        visibility = None):
+    """Tests a C/C++ binary without testing framework dependencies.
+
+    Runs a C++ binary, and tests that the output logs contain the
+    expected value. This is a deliberately spartan way of testing, to match
+    what's available when testing microcontroller binaries.
+
+    Two targets are created: a `cc_binary` called `<name>_binary`, and an
+    `sh_test` called `<name>` that runs `test_linux_binary.sh` on that binary.
+
+    Args:
+      name: a unique name for this rule.
+      expected_in_logs: A regular expression that is required to be
+        present in the binary's logs for the test to pass.
+      srcs: sources to compile (C, C++, ld scripts).
+      includes: include paths to add to this rule and its dependents.
+      defines: list of `VAR` or `VAR=VAL` to pass to CPP for this rule and
+        its dependents.
+      copts: gcc compilation flags for this rule only.
+      nocopts: list of gcc compilation flags to remove for this rule
+        only. No regexp like for `cc_library`.
+      linkopts: `gcc` flags to add to the linking phase. For "pure" ld flags,
+        prefix them with the `-Wl,` prefix here.
+      deps: dependencies. only `tflite_bare_metal_cc_library()` dependencies
+        allowed.
+      tags: tags applied to both the generated binary and the test target.
+      visibility: visibility of the generated binary target.
+    """
+    native.cc_binary(
+        name = name + "_binary",
+        srcs = srcs,
+        includes = includes,
+        defines = defines,
+        copts = copts,
+        nocopts = nocopts,
+        linkopts = linkopts,
+        deps = deps,
+        tags = tags,
+        visibility = visibility,
+    )
+    native.sh_test(
+        name = name,
+        size = "medium",
+        srcs = [
+            "//tensorflow/contrib/lite/experimental/micro/testing:test_linux_binary.sh",
+        ],
+        args = [
+            native.package_name() + "/" + name + "_binary",
+            "'" + expected_in_logs + "'",
+        ],
+        data = [
+            name + "_binary",
+            # Internal test dependency placeholder
+        ],
+        deps = [
+        ],
+        tags = tags,
+    )
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/micro_test.h b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.h
new file mode 100644
index 0000000000000000000000000000000000000000..3b6554dea6a59feb2d2675ef58005e21a8001887
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.h
@@ -0,0 +1,156 @@
+/* Copyright 2018 The TensorFlow Authors.
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// An ultra-lightweight testing framework designed for use with microcontroller +// applications. Its only dependency is on TensorFlow Lite's ErrorReporter +// interface, where log messages are output. This is designed to be usable even +// when no standard C or C++ libraries are available, and without any dynamic +// memory allocation or reliance on global constructors. +// +// To build a test, you use syntax similar to gunit, but with some extra +// decoration to create a hidden 'main' function containing each of the tests to +// be run. Your code should look something like: +// ---------------------------------------------------------------------------- +// #include "path/to/this/header" +// +// TF_LITE_MICRO_TESTS_BEGIN +// +// TF_LITE_MICRO_TEST(SomeTest) { +// TF_LITE_LOG_EXPECT_EQ(true, true); +// } +// +// TF_LITE_MICRO_TESTS_END +// ---------------------------------------------------------------------------- +// If you compile this for your platform, you'll get a normal binary that you +// should be able to run. 
Executing it will output logging information like this +// to stderr (or whatever equivalent is available and written to by +// ErrorReporter): +// ---------------------------------------------------------------------------- +// Testing SomeTest +// 1/1 tests passed +// ~~~ALL TESTS PASSED~~~ +// ---------------------------------------------------------------------------- +// This is designed to be human-readable, so you can just run tests manually, +// but the string "~~~ALL TESTS PASSED~~~" should only appear if all of the +// tests do pass. This makes it possible to integrate with automated test +// systems by scanning the output logs and looking for that magic value. +// +// This framework is intended to be a rudimentary alternative to no testing at +// all on systems that struggle to run more conventional approaches, so use with +// caution! + +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_ + +#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h" + +namespace micro_test { +extern int tests_passed; +extern int tests_failed; +extern bool is_test_complete; +extern bool did_test_fail; +extern tflite::ErrorReporter* reporter; +} // namespace micro_test + +#define TF_LITE_MICRO_TESTS_BEGIN \ + namespace micro_test { \ + int tests_passed; \ + int tests_failed; \ + bool is_test_complete; \ + bool did_test_fail; \ + tflite::ErrorReporter* reporter; \ + } \ + \ + int main(int argc, char** argv) { \ + micro_test::tests_passed = 0; \ + micro_test::tests_failed = 0; \ + tflite::MicroErrorReporter error_reporter; \ + micro_test::reporter = &error_reporter; + +#define TF_LITE_MICRO_TESTS_END \ + micro_test::reporter->Report( \ + "%d/%d tests passed", micro_test::tests_passed, \ + (micro_test::tests_failed + micro_test::tests_passed)); \ + if (micro_test::tests_failed == 0) { \ + micro_test::reporter->Report("~~~ALL TESTS PASSED~~~\n"); \ + } else { \ + 
micro_test::reporter->Report("~~~SOME TESTS FAILED~~~\n"); \ + } \ + } + +// TODO(petewarden): I'm going to hell for what I'm doing to this poor for loop. +#define TF_LITE_MICRO_TEST(name) \ + micro_test::reporter->Report("Testing %s", #name); \ + for (micro_test::is_test_complete = false, \ + micro_test::did_test_fail = false; \ + !micro_test::is_test_complete; micro_test::is_test_complete = true, \ + micro_test::tests_passed += (micro_test::did_test_fail) ? 0 : 1, \ + micro_test::tests_failed += (micro_test::did_test_fail) ? 1 : 0) + +#define TF_LITE_MICRO_EXPECT(x) \ + do { \ + if (!(x)) { \ + micro_test::reporter->Report(#x " failed at %s:%d", __FILE__, __LINE__); \ + micro_test::did_test_fail = true; \ + } \ + } while (false) + +#define TF_LITE_MICRO_EXPECT_EQ(x, y) \ + do { \ + if ((x) != (y)) { \ + micro_test::reporter->Report(#x " == " #y " failed at %s:%d", __FILE__, \ + __LINE__); \ + micro_test::did_test_fail = true; \ + } \ + } while (false) + +#define TF_LITE_MICRO_EXPECT_NE(x, y) \ + do { \ + if ((x) == (y)) { \ + micro_test::reporter->Report(#x " != " #y " failed at %s:%d", __FILE__, \ + __LINE__); \ + micro_test::did_test_fail = true; \ + } \ + } while (false) + +#define TF_LITE_MICRO_EXPECT_NEAR(x, y, epsilon) \ + do { \ + auto delta = ((x) > (y)) ? 
((x) - (y)) : ((y) - (x)); \ + if (delta > epsilon) { \ + micro_test::reporter->Report(#x " near " #y " failed at %s:%d", \ + __FILE__, __LINE__); \ + micro_test::did_test_fail = true; \ + } \ + } while (false) + +#define TF_LITE_MICRO_EXPECT_GT(x, y) \ + do { \ + if ((x) <= (y)) { \ + micro_test::reporter->Report(#x " > " #y " failed at %s:%d", __FILE__, \ + __LINE__); \ + micro_test::did_test_fail = true; \ + } \ + } while (false) + +#define TF_LITE_MICRO_EXPECT_LT(x, y) \ + do { \ + if ((x) >= (y)) { \ + micro_test::reporter->Report(#x " < " #y " failed at %s:%d", __FILE__, \ + __LINE__); \ + micro_test::did_test_fail = true; \ + } \ + } while (false) + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh b/tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh new file mode 100755 index 0000000000000000000000000000000000000000..a470dc52f8d84006f55676201caf8af3ff07d0b7 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh @@ -0,0 +1,56 @@ +#!/bin/bash -e +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Tests a 'bluepill' STM32F103 ELF by parsing the log output of Renode emulation. +# +# First argument is the ELF location. 
# Usage: test_bluepill_binary.sh <elf-path> <expected-regex>
#
# Builds the renode_bluepill Docker image, runs the ELF given as the first
# argument under it, prints the captured log, and exits 0 only when the
# `docker run` step succeeded. The second argument is handed to the container
# as EXPECTED.
if [ "$#" -lt 2 ]; then
  echo "Usage: $0 <elf-path> <expected-regex>" >&2
  exit 1
fi

declare -r ROOT_DIR="$(pwd)"
declare -r TEST_TMPDIR=/tmp/test_bluepill_binary/
declare -r MICRO_LOG_PATH="${TEST_TMPDIR}"
declare -r MICRO_LOG_FILENAME="${MICRO_LOG_PATH}/logs.txt"
mkdir -p "${MICRO_LOG_PATH}"

# All expansions are quoted so a checkout path or ELF path containing spaces
# is passed through as a single argument.
docker build -t renode_bluepill \
  -f "${ROOT_DIR}/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill" \
  "${ROOT_DIR}/tensorflow/contrib/lite/experimental/micro/testing/"

exit_code=0
# running in `if` to avoid setting +e
# NOTE(review): inside the container command, `2>&1 >file` sends only stdout
# to the log file (stderr still goes to the terminal). Left unchanged; confirm
# whether `>file 2>&1` was intended.
if ! docker run \
  --log-driver=none -a stdout -a stderr \
  -v "${ROOT_DIR}:/workspace" \
  -v /tmp:/tmp \
  -e "BIN=/workspace/$1" \
  -e SCRIPT=/workspace/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc \
  -e EXPECTED="$2" \
  -it renode_bluepill \
  /bin/bash -c "/opt/renode/tests/test.sh /workspace/tensorflow/contrib/lite/experimental/micro/testing/bluepill.robot 2>&1 >${MICRO_LOG_FILENAME}"
then
  exit_code=1
fi

echo "LOGS:"
cat "${MICRO_LOG_FILENAME}"
if [ "${exit_code}" -eq 0 ]
then
  echo "$1: PASS"
else
  echo "$1: FAIL - '$2' not found in logs."
fi
exit "${exit_code}"
# Usage: test_linux_binary.sh <binary-path> <expected-regex>
#
# Runs the binary given as the first argument, mirrors its combined
# stdout/stderr to the terminal and to a log file, and passes only when the
# regular expression given as the second argument matches somewhere in the log.
if [ "$#" -lt 2 ]; then
  echo "Usage: $0 <binary-path> <expected-regex>" >&2
  exit 1
fi

declare -r ROOT_DIR="$(pwd)"
# NOTE(review): the directory name matches the one used by the bluepill
# script; it looks like a copy/paste leftover but is kept so existing log
# locations do not move.
declare -r TEST_TMPDIR=/tmp/test_bluepill_binary/
declare -r MICRO_LOG_PATH="${TEST_TMPDIR}/$1"
declare -r MICRO_LOG_FILENAME="${MICRO_LOG_PATH}/logs.txt"
mkdir -p "${MICRO_LOG_PATH}"

# The pipeline's status is tee's, so a failing test binary does not abort the
# script under `bash -e`; the verdict comes from the grep below.
"$1" 2>&1 | tee "${MICRO_LOG_FILENAME}"

# `-e` keeps a pattern that begins with '-' from being parsed as an option.
if grep -q -e "$2" "${MICRO_LOG_FILENAME}"
then
  echo "$1: PASS"
  exit 0
else
  echo "$1: FAIL - '$2' not found in logs."
  exit 1
fi
For example: +# make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile TARGET=bluepill +TARGET := $(HOST_OS) +TARGET_ARCH := $(HOST_ARCH) + +INCLUDES := \ +-I. \ +-I$(MAKEFILE_DIR)/../../../../../ \ +-I$(MAKEFILE_DIR)/../../../../../../ \ +-I$(MAKEFILE_DIR)/downloads/ \ +-I$(MAKEFILE_DIR)/downloads/gemmlowp \ +-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ +-I$(OBJDIR) +# This is at the end so any globally-installed frameworks like protobuf don't +# override local versions in the source tree. +INCLUDES += -I/usr/local/include + +TEST_SCRIPT := tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh + +MICROLITE_LIBS := -lm + +# There are no rules for compiling objects for the host system (since we don't +# generate things like the protobuf compiler that require that), so all of +# these settings are for the target compiler. +CXXFLAGS := -O3 -DNDEBUG +CXXFLAGS += --std=c++11 -g -DTF_LITE_STATIC_MEMORY +CCFLAGS := -DNDEBUG -g -DTF_LITE_STATIC_MEMORY +LDOPTS := -L/usr/local/lib +ARFLAGS := -r +TARGET_TOOLCHAIN_PREFIX := +CC_PREFIX := + +# This library is the main target for this makefile. It will contain a minimal +# runtime that can be linked in to other programs. +MICROLITE_LIB_NAME := libtensorflow-microlite.a + +# Test binary for the microcontroller speech model. 
+MICRO_SPEECH_TEST_SRCS := \ +tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc \ +tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc \ +tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.cc \ +tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.cc + +MICROLITE_TEST_SRCS := \ +$(wildcard tensorflow/contrib/lite/experimental/micro/*test.cc) \ +$(wildcard tensorflow/contrib/lite/experimental/micro/kernels/*test.cc) + +MICROLITE_CC_BASE_SRCS := \ +$(wildcard tensorflow/contrib/lite/experimental/micro/*.cc) \ +$(wildcard tensorflow/contrib/lite/experimental/micro/kernels/*.cc) \ +tensorflow/contrib/lite/c/c_api_internal.c \ +tensorflow/contrib/lite/core/api/error_reporter.cc \ +tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc \ +tensorflow/contrib/lite/core/api/op_resolver.cc \ +tensorflow/contrib/lite/kernels/kernel_util.cc \ +tensorflow/contrib/lite/kernels/internal/quantization_util.cc +MICROLITE_CC_SRCS := $(filter-out $(MICROLITE_TEST_SRCS), $(MICROLITE_CC_BASE_SRCS)) + +# These target-specific makefiles should modify or replace options like +# CXXFLAGS or LIBS to work for a specific targeted architecture. All logic +# based on platforms or architectures should happen within these files, to +# keep this main makefile focused on the sources and dependencies. +include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc) + +ALL_SRCS := \ + $(MICRO_SPEECH_TEST_SRCS) \ + $(MICROLITE_CC_SRCS) \ + $(MICROLITE_TEST_SRCS) + +# Where compiled objects are stored. 
+GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/ +OBJDIR := $(GENDIR)obj/ +BINDIR := $(GENDIR)bin/ +LIBDIR := $(GENDIR)lib/ + +MICROLITE_LIB_PATH := $(LIBDIR)$(MICROLITE_LIB_NAME) + +MICRO_SPEECH_TEST_BINARY := $(BINDIR)micro_speech_test + +CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++ +CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc +AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar + +MICRO_SPEECH_TEST_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICRO_SPEECH_TEST_SRCS)))) + +MICROLITE_LIB_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICROLITE_CC_SRCS)))) + +MICROLITE_TEST_TARGETS := $(addprefix $(BINDIR), \ +$(patsubst %_test.cc,%.test_target,$(MICROLITE_TEST_SRCS))) + +# For normal manually-created TensorFlow C++ source files. +$(OBJDIR)%.o: %.cc + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ + +# For normal manually-created TensorFlow C source files. +$(OBJDIR)%.o: %.c + @mkdir -p $(dir $@) + $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ + +# The target that's compiled if there's no command-line arguments. +all: $(MICROLITE_LIB_PATH) $(MICRO_SPEECH_TEST_BINARY) + +microlite: $(MICROLITE_LIB_PATH) + +# Hack for generating schema file bypassing flatbuffer parsing +tensorflow/contrib/lite/schema/schema_generated.h: + @cp -u tensorflow/contrib/lite/schema/schema_generated.h.OPENSOURCE tensorflow/contrib/lite/schema/schema_generated.h + +# Gathers together all the objects we've compiled into a single '.a' archive. 
+$(MICROLITE_LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(MICROLITE_LIB_OBJS) + @mkdir -p $(dir $@) + $(AR) $(ARFLAGS) $(MICROLITE_LIB_PATH) $(MICROLITE_LIB_OBJS) + +$(MICRO_SPEECH_TEST_BINARY): $(MICRO_SPEECH_TEST_OBJS) $(MICROLITE_LIB_PATH) + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) \ + -o $(MICRO_SPEECH_TEST_BINARY) $(MICRO_SPEECH_TEST_OBJS) \ + $(LIBFLAGS) $(MICROLITE_LIB_PATH) $(LDFLAGS) $(MICROLITE_LIBS) + +micro_speech_test: $(MICRO_SPEECH_TEST_BINARY) +micro_speech_test_bin: $(MICRO_SPEECH_TEST_BINARY).bin + +test_micro_speech: $(MICRO_SPEECH_TEST_BINARY) + $(TEST_SCRIPT) $(MICRO_SPEECH_TEST_BINARY) '~~~ALL TESTS PASSED~~~' + +$(BINDIR)%_test : $(OBJDIR)%_test.o $(MICROLITE_LIB_PATH) + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) \ + -o $@ $< \ + $(LIBFLAGS) $(MICROLITE_LIB_PATH) $(LDFLAGS) $(MICROLITE_LIBS) + +$(BINDIR)%.test_target: $(BINDIR)%_test + $(TEST_SCRIPT) $< '~~~ALL TESTS PASSED~~~' + +$(info $(MICROLITE_TEST_TARGETS)) + +test: test_micro_speech $(MICROLITE_TEST_TARGETS) + +# Gets rid of all generated files. +clean: + rm -rf $(MAKEFILE_DIR)/gen + +$(DEPDIR)/%.d: ; +.PRECIOUS: $(DEPDIR)/%.d +.PRECIOUS: $(BINDIR)%_test + +-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(ALL_SRCS))) diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh b/tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh new file mode 100755 index 0000000000000000000000000000000000000000..4c2ff8545dbdcc426bf62aaeb07ca22d8b17cc69 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh @@ -0,0 +1,73 @@ +#!/bin/bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# Abort on the first failing command, and treat a failure anywhere in a
# pipeline (e.g. curl in `curl | tar`) as a failure of the whole pipeline.
set -e
set -o pipefail

# Work from the repository root regardless of where the script was invoked.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR/../../../../../../.."

DOWNLOADS_DIR=tensorflow/contrib/lite/experimental/micro/tools/make/downloads
BZL_FILE_PATH=tensorflow/workspace.bzl

# Ensure it is being run from repo root
if [ ! -f "${BZL_FILE_PATH}" ]; then
  echo "Could not find ${BZL_FILE_PATH}:" >&2
  echo "Likely you are not running this from the root directory of the repository." >&2
  exit 1
fi

GEMMLOWP_URL="https://github.com/google/gemmlowp/archive/719139ce755a0f31cbf1c37f7f98adcc7fc9f425.zip"
FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/1f5eae5d6a135ff6811724f6c57f911d1f46bb15.tar.gz"
CMSIS_URL="https://github.com/ARM-software/CMSIS_5/archive/5.4.0.zip"
STM32_BARE_LIB_URL="https://github.com/google/stm32_bare_lib/archive/50e0da307a2821bb54af1f57b969e6b76cb89d32.zip"

# download_and_extract URL DIR
#
# Downloads the archive at URL and unpacks it into DIR (created if missing),
# dropping the archive's single top-level directory. URLs ending in "gz" are
# treated as gzipped tarballs and URLs ending in "zip" as zip files; any other
# URL is rejected. BUILD files inside DIR are removed afterwards.
download_and_extract() {
  local usage="Usage: download_and_extract URL DIR"
  local url="${1:?${usage}}"
  local dir="${2:?${usage}}"
  echo "downloading ${url}" >&2
  mkdir -p "${dir}"
  if [[ "${url}" == *gz ]]; then
    # -f makes curl fail on an HTTP error instead of piping the error page
    # into tar.
    curl -fLs "${url}" | tar -C "${dir}" --strip-components=1 -xz
  elif [[ "${url}" == *zip ]]; then
    # Keep the scratch directory names out of the global scope.
    local tempdir tempdir2
    tempdir=$(mktemp -d)
    tempdir2=$(mktemp -d)

    curl -fL "${url}" > "${tempdir}/zipped.zip"
    unzip "${tempdir}/zipped.zip" -d "${tempdir2}"

    # If the zip file contains nested directories, extract the files from the
    # inner directory.
    if ls "${tempdir2}"/*/* 1> /dev/null 2>&1; then
      # unzip has no strip components, so unzip to a temp dir, and move the
      # files we want from the tempdir to destination.
      cp -R "${tempdir2}"/*/* "${dir}/"
    else
      cp -R "${tempdir2}"/* "${dir}/"
    fi
    rm -rf "${tempdir2}" "${tempdir}"
  else
    echo "download_and_extract: unsupported archive type: ${url}" >&2
    return 1
  fi

  # Delete any potential BUILD files, which would interfere with Bazel builds.
  find "${dir}" -type f -name '*BUILD' -delete
}

download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp"
download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers"
download_and_extract "${CMSIS_URL}" "${DOWNLOADS_DIR}/cmsis"
download_and_extract "${STM32_BARE_LIB_URL}" "${DOWNLOADS_DIR}/stm32_bare_lib"

echo "download_dependencies.sh completed successfully." >&2
+ifeq ($(TARGET), bluepill) + TARGET_ARCH := cortex-m3 + TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- + + PLATFORM_FLAGS = \ + -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -DTF_LITE_STATIC_MEMORY \ + -DTF_LITE_MCU_DEBUG_LOG \ + -fno-rtti \ + -fmessage-length=0 \ + -fno-exceptions \ + -fno-unwind-tables \ + -fno-builtin \ + -ffunction-sections \ + -fdata-sections \ + -funsigned-char \ + -MMD \ + -mcpu=cortex-m3 \ + -mthumb \ + -std=gnu++11 \ + -Wvla \ + -Wall \ + -Wextra \ + -Wno-unused-parameter \ + -Wno-missing-field-initializers \ + -Wno-write-strings \ + -Wno-sign-compare \ + -fno-delete-null-pointer-checks \ + -fomit-frame-pointer \ + -fpermissive \ + -nostdlib \ + -g \ + -Os + CXXFLAGS += $(PLATFORM_FLAGS) + CCFLAGS += $(PLATFORM_FLAGS) + LDFLAGS += \ + -T $(MAKEFILE_DIR)/downloads/stm32_bare_lib/stm32_linker_layout.lds \ + -Wl,-Map=$(MAKEFILE_DIR)/gen/$(TARGET).map,--cref \ + -Wl,--gc-sections + BUILD_TYPE := micro + MICROLITE_LIBS := \ + -lm + INCLUDES += \ + -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/Core/Include/ \ + -I$(MAKEFILE_DIR)/downloads/stm32_bare_lib/include + MICROLITE_CC_SRCS += \ + $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.c) \ + $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.cc) + TEST_SCRIPT := tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh + # These are tests that don't currently work on the blue pill. + EXCLUDED_TESTS := \ + tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc \ + tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc + MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) + +# These are microcontroller-specific rules for converting the ELF output +# of the linker into a binary image that can be loaded directly. 
+OBJCOPY := $(TARGET_TOOLCHAIN_PREFIX)objcopy + +$(BINDIR)/%.bin: $(BINDIR)/% + @mkdir -p $(dir $@) + $(OBJCOPY) $< $@ -O binary + +endif \ No newline at end of file diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/BUILD b/tensorflow/contrib/lite/experimental/microfrontend/lib/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..3fd4b9fe82f7959fd86df7696950a9d3ae205042 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/BUILD @@ -0,0 +1,188 @@ +# Library for generating feature vectors from audio data + +package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "bits", + hdrs = ["bits.h"], +) + +cc_library( + name = "fft", + srcs = [ + "fft.c", + "fft_util.c", + ], + hdrs = [ + "fft.h", + "fft_util.h", + ], + deps = ["@kissfft//:kiss_fftr_16"], +) + +cc_library( + name = "filterbank", + srcs = [ + "filterbank.c", + "filterbank_util.c", + ], + hdrs = [ + "filterbank.h", + "filterbank_util.h", + ], + deps = [ + ":bits", + ":fft", + ], +) + +cc_library( + name = "frontend", + srcs = [ + "frontend.c", + "frontend_util.c", + ], + hdrs = [ + "frontend.h", + "frontend_util.h", + ], + deps = [ + ":bits", + ":fft", + ":filterbank", + ":log_scale", + ":noise_reduction", + ":pcan_gain_control", + ":window", + ], +) + +cc_library( + name = "log_scale", + srcs = [ + "log_lut.c", + "log_scale.c", + "log_scale_util.c", + ], + hdrs = [ + "log_lut.h", + "log_scale.h", + "log_scale_util.h", + ], + deps = [ + ":bits", + ], +) + +cc_library( + name = "noise_reduction", + srcs = [ + "noise_reduction.c", + "noise_reduction_util.c", + ], + hdrs = [ + "noise_reduction.h", + "noise_reduction_util.h", + ], +) + +cc_library( + name = "pcan_gain_control", + srcs = [ + "pcan_gain_control.c", + "pcan_gain_control_util.c", + ], + hdrs = [ + "pcan_gain_control.h", + "pcan_gain_control_util.h", + ], + deps = [ + ":bits", + ], +) + +cc_library( + name = "window", + srcs = [ 
+ "window.c", + "window_util.c", + ], + hdrs = [ + "window.h", + "window_util.h", + ], +) + +cc_test( + name = "fft_test", + size = "small", + srcs = ["fft_test.cc"], + deps = [ + ":fft", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "filterbank_test", + size = "small", + srcs = ["filterbank_test.cc"], + deps = [ + ":filterbank", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "frontend_test", + size = "small", + srcs = ["frontend_test.cc"], + deps = [ + ":frontend", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "log_scale_test", + size = "small", + srcs = ["log_scale_test.cc"], + deps = [ + ":log_scale", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "noise_reduction_test", + size = "small", + srcs = ["noise_reduction_test.cc"], + deps = [ + ":noise_reduction", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "pcan_gain_control_test", + size = "small", + srcs = ["pcan_gain_control_test.cc"], + deps = [ + ":pcan_gain_control", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "window_test", + size = "small", + srcs = ["window_test.cc"], + deps = [ + ":window", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/README b/tensorflow/contrib/lite/experimental/microfrontend/lib/README new file mode 100644 index 0000000000000000000000000000000000000000..731d88c5bdaafe225d847a619192ba16cec7c25f --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/README @@ -0,0 +1,9 @@ +The binary frontend_main shows sample usage of the frontend, printing out +coefficients when it has processed enough data. + +The binary frontend_memmap_main shows a sample usage of how to avoid all the +init code in your runtime, by first running "frontend_generate_memmap" to +create a header/source file that uses a baked in frontend state. 
#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_BITS_H_
#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_BITS_H_

// <stdint.h> is valid in both C and C++, so the fixed-width types used below
// are available no matter which language includes this header or what was
// included before it.
#include <stdint.h>

#if defined(_MSC_VER)
#include <intrin.h>  // _BitScanReverse, _BitScanReverse64
#endif

#ifdef __cplusplus
extern "C" {
#endif

// Portable count of leading zero bits for a value that fits in 32 bits.
// Returns 32 for n == 0. The parameter is 64 bits wide, but only the low 32
// bits are examined, so callers must pass values <= 0xFFFFFFFF.
static inline int CountLeadingZeros32Slow(uint64_t n) {
  int zeroes = 28;
  // Narrow n down to its top non-zero nibble, subtracting the bits skipped.
  if (n >> 16) zeroes -= 16, n >>= 16;
  if (n >> 8) zeroes -= 8, n >>= 8;
  if (n >> 4) zeroes -= 4, n >>= 4;
  // Lookup of the leading-zero count of a nibble; index 15 lands on the
  // string's terminating NUL, which supplies the final 0 entry.
  return "\4\3\2\2\1\1\1\1\0\0\0\0\0\0\0"[n] + zeroes;
}

// Number of leading zero bits in n; returns 32 for n == 0.
static inline int CountLeadingZeros32(uint32_t n) {
#if defined(_MSC_VER)
  unsigned long result = 0;  // NOLINT(runtime/int)
  if (_BitScanReverse(&result, n)) {
    return 31 - (int)result;
  }
  return 32;
#elif defined(__GNUC__)

  // Handle 0 as a special case because __builtin_clz(0) is undefined.
  if (n == 0) {
    return 32;
  }
  return __builtin_clz(n);
#else
  return CountLeadingZeros32Slow(n);
#endif
}

// 1-based position of the highest set bit of n (0 when n == 0).
static inline int MostSignificantBit32(uint32_t n) {
  return 32 - CountLeadingZeros32(n);
}

// Portable count of leading zero bits in a 64-bit value; returns 64 for
// n == 0. Same narrowing-plus-nibble-table scheme as the 32-bit variant.
static inline int CountLeadingZeros64Slow(uint64_t n) {
  int zeroes = 60;
  if (n >> 32) zeroes -= 32, n >>= 32;
  if (n >> 16) zeroes -= 16, n >>= 16;
  if (n >> 8) zeroes -= 8, n >>= 8;
  if (n >> 4) zeroes -= 4, n >>= 4;
  return "\4\3\2\2\1\1\1\1\0\0\0\0\0\0\0"[n] + zeroes;
}

// Number of leading zero bits in n; returns 64 for n == 0.
static inline int CountLeadingZeros64(uint64_t n) {
#if defined(_MSC_VER) && defined(_M_X64)
  // MSVC does not have __builtin_clzll. Use _BitScanReverse64.
  unsigned long result = 0;  // NOLINT(runtime/int)
  if (_BitScanReverse64(&result, n)) {
    return 63 - (int)result;
  }
  return 64;
#elif defined(_MSC_VER)
  // MSVC does not have __builtin_clzll. Compose two calls to _BitScanReverse
  unsigned long result = 0;  // NOLINT(runtime/int)
  // Upper half first: a bit found there is at position result + 32.
  if ((n >> 32) && _BitScanReverse(&result, (unsigned long)(n >> 32))) {
    return 31 - (int)result;
  }
  if (_BitScanReverse(&result, (unsigned long)n)) {
    return 63 - (int)result;
  }
  return 64;
#elif defined(__GNUC__)

  // Handle 0 as a special case because __builtin_clzll(0) is undefined.
  if (n == 0) {
    return 64;
  }
  return __builtin_clzll(n);
#else
  return CountLeadingZeros64Slow(n);
#endif
}

// 1-based position of the highest set bit of n (0 when n == 0).
static inline int MostSignificantBit64(uint64_t n) {
  return 64 - CountLeadingZeros64(n);
}

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_BITS_H_
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h" + +#include + +#define FIXED_POINT 16 +#include "kiss_fft.h" +// Internal test dependency placeholder1 +// Internal test dependency placeholder2 +#include "tools/kiss_fftr.h" +// Internal test dependency placeholder3 + +void FftCompute(struct FftState* state, const int16_t* input, + int input_scale_shift) { + const size_t input_size = state->input_size; + const size_t fft_size = state->fft_size; + + int16_t* fft_input = state->input; + // First, scale the input by the given shift. + int i; + for (i = 0; i < input_size; ++i) { + *fft_input++ = (*input++) << input_scale_shift; + } + // Zero out whatever else remains in the top part of the input. + for (; i < fft_size; ++i) { + *fft_input++ = 0; + } + + // Apply the FFT. 
+ kiss_fftr((const kiss_fftr_cfg)state->scratch, state->input, + (kiss_fft_cpx*)state->output); +} + +void FftInit(struct FftState* state) { + // All the initialization is done in FftPopulateState() +} + +void FftReset(struct FftState* state) { + memset(state->input, 0, state->fft_size * sizeof(*state->input)); + memset(state->output, 0, (state->fft_size / 2 + 1) * sizeof(*state->output)); +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h new file mode 100644 index 0000000000000000000000000000000000000000..e7644bf2a70f5185db8c7a9356e1ee145ff14bd2 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct complex_int16_t { + int16_t real; + int16_t imag; +}; + +struct FftState { + int16_t* input; + struct complex_int16_t* output; + size_t fft_size; + size_t input_size; + void* scratch; + size_t scratch_size; +}; + +void FftCompute(struct FftState* state, const int16_t* input, + int input_scale_shift); + +void FftInit(struct FftState* state); + +void FftReset(struct FftState* state); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.c new file mode 100644 index 0000000000000000000000000000000000000000..cc1ce209d8501b3d41199c83362a7576031d5005 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.c @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.h" + +void FftWriteMemmapPreamble(FILE* fp, const struct FftState* state) { + fprintf(fp, "static int16_t fft_input[%zu];\n", state->fft_size); + fprintf(fp, "static struct complex_int16_t fft_output[%zu];\n", + state->fft_size / 2 + 1); + fprintf(fp, "static char fft_scratch[%zu];\n", state->scratch_size); + fprintf(fp, "\n"); +} + +void FftWriteMemmap(FILE* fp, const struct FftState* state, + const char* variable) { + fprintf(fp, "%s->input = fft_input;\n", variable); + fprintf(fp, "%s->output = fft_output;\n", variable); + fprintf(fp, "%s->fft_size = %zu;\n", variable, state->fft_size); + fprintf(fp, "%s->input_size = %zu;\n", variable, state->input_size); + fprintf(fp, "%s->scratch = fft_scratch;\n", variable); + fprintf(fp, "%s->scratch_size = %zu;\n", variable, state->scratch_size); +} diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.h similarity index 54% rename from tensorflow/compiler/xla/service/gpu/gpu_options.h rename to tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.h index 498d4a94955cb2c50e0b165f28ded44ac1c0bfff..4d10c3a92af7e86d572f1775e9958d29d12bad2e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_options.h +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.h @@ -12,22 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_IO_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_IO_H_ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ +#include -#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h" -// Helper functions for querying options that are specific to the GPU backend. +#ifdef __cplusplus +extern "C" { +#endif -namespace xla { -namespace gpu { +void FftWriteMemmapPreamble(FILE* fp, const struct FftState* state); +void FftWriteMemmap(FILE* fp, const struct FftState* state, + const char* variable); -// Returns true if we should use heuristics to assign convolution layouts, as -// opposed to always assigning NCHW. -bool ConvUseLayoutHeuristic(const HloModuleConfig& config); +#ifdef __cplusplus +} // extern "C" +#endif -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_IO_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_test.cc b/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b8684a0b5c0187afb28fac6f11c8231b75abc1d4 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_test.cc @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.h" + +#include +#include + +namespace { + +const int16_t kFakeWindow[] = { + 0, 1151, 0, -5944, 0, 13311, 0, -21448, 0, 28327, 0, -32256, 0, 32255, + 0, -28328, 0, 21447, 0, -13312, 0, 5943, 0, -1152, 0}; +const int kScaleShift = 0; + +TEST(FftTest, CheckOutputValues) { + struct FftState state; + ASSERT_TRUE( + FftPopulateState(&state, sizeof(kFakeWindow) / sizeof(kFakeWindow[0]))); + + FftInit(&state); + FftCompute(&state, kFakeWindow, kScaleShift); + + const struct complex_int16_t expected[] = { + {0, 0}, {-10, 9}, {-20, 0}, {-9, -10}, {0, 25}, {-119, 119}, + {-887, 0}, {3000, 3000}, {0, -6401}, {-3000, 3000}, {886, 0}, {118, 119}, + {0, 25}, {9, -10}, {19, 0}, {9, 9}, {0, 0}}; + ASSERT_EQ(state.fft_size / 2 + 1, sizeof(expected) / sizeof(expected[0])); + for (int i = 0; i <= state.fft_size / 2; ++i) { + EXPECT_EQ(state.output[i].real, expected[i].real); + EXPECT_EQ(state.output[i].imag, expected[i].imag); + } + + FftFreeStateContents(&state); +} + +} // namespace diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.c new file mode 100644 index 0000000000000000000000000000000000000000..55494422f375e111e0e6c5c47f7cf193a364006a --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.c @@ -0,0 +1,71 @@ +/* Copyright 2018 The 
TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.h" + +#include + +#define FIXED_POINT 16 +#include "kiss_fft.h" +#include "tools/kiss_fftr.h" + +int FftPopulateState(struct FftState* state, size_t input_size) { + state->input_size = input_size; + state->fft_size = 1; + while (state->fft_size < state->input_size) { + state->fft_size <<= 1; + } + + state->input = malloc(state->fft_size * sizeof(*state->input)); + if (state->input == NULL) { + fprintf(stderr, "Failed to alloc fft input buffer\n"); + return 0; + } + + state->output = + malloc((state->fft_size / 2 + 1) * sizeof(*state->output) * 2); + if (state->output == NULL) { + fprintf(stderr, "Failed to alloc fft output buffer\n"); + return 0; + } + + // Ask kissfft how much memory it wants. 
+ size_t scratch_size = 0; + kiss_fftr_cfg kfft_cfg = + kiss_fftr_alloc(state->fft_size, 0, NULL, &scratch_size); + if (kfft_cfg != NULL) { + fprintf(stderr, "Kiss memory sizing failed.\n"); + return 0; + } + state->scratch = malloc(scratch_size); + if (state->scratch == NULL) { + fprintf(stderr, "Failed to alloc fft scratch buffer\n"); + return 0; + } + state->scratch_size = scratch_size; + // Let kissfft configure the scratch space we just allocated + kfft_cfg = kiss_fftr_alloc(state->fft_size, 0, state->scratch, &scratch_size); + if (kfft_cfg != state->scratch) { + fprintf(stderr, "Kiss memory preallocation strategy failed.\n"); + return 0; + } + return 1; +} + +void FftFreeStateContents(struct FftState* state) { + free(state->input); + free(state->output); + free(state->scratch); +} + diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.h new file mode 100644 index 0000000000000000000000000000000000000000..4935e87fc1ab8b7d41063793b43d10b9402badd9 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_UTIL_H_ + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Prepares and FFT for the given input size. +int FftPopulateState(struct FftState* state, size_t input_size); + +// Frees any allocated buffers. +void FftFreeStateContents(struct FftState* state); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.c new file mode 100644 index 0000000000000000000000000000000000000000..944eb1a7379746dbe94715838e12e6fcf9f3a5df --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.c @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h" + +#include + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/bits.h" + +void FilterbankConvertFftComplexToEnergy(struct FilterbankState* state, + struct complex_int16_t* fft_output, + int32_t* energy) { + const int end_index = state->end_index; + int i; + energy += state->start_index; + fft_output += state->start_index; + for (i = state->start_index; i < end_index; ++i) { + const int32_t real = fft_output->real; + const int32_t imag = fft_output->imag; + fft_output++; + const uint32_t mag_squared = (real * real) + (imag * imag); + *energy++ = mag_squared; + } +} + +void FilterbankAccumulateChannels(struct FilterbankState* state, + const int32_t* energy) { + uint64_t* work = state->work; + uint64_t weight_accumulator = 0; + uint64_t unweight_accumulator = 0; + + const int16_t* channel_frequency_starts = state->channel_frequency_starts; + const int16_t* channel_weight_starts = state->channel_weight_starts; + const int16_t* channel_widths = state->channel_widths; + + int num_channels_plus_1 = state->num_channels + 1; + int i; + for (i = 0; i < num_channels_plus_1; ++i) { + const int32_t* magnitudes = energy + *channel_frequency_starts++; + const int16_t* weights = state->weights + *channel_weight_starts; + const int16_t* unweights = state->unweights + *channel_weight_starts++; + const int width = *channel_widths++; + int j; + for (j = 0; j < width; ++j) { + weight_accumulator += *weights++ * ((uint64_t) *magnitudes); + unweight_accumulator += *unweights++ * ((uint64_t) *magnitudes); + ++magnitudes; + } + *work++ = weight_accumulator; + weight_accumulator = unweight_accumulator; + unweight_accumulator = 0; + } +} + +static uint16_t Sqrt32(uint32_t num) { + if (num == 0) { + return 0; + } + uint32_t res = 0; + int max_bit_number = 32 - MostSignificantBit32(num); + max_bit_number |= 1; + 
uint32_t bit = 1U << (31 - max_bit_number); + int iterations = (31 - max_bit_number) / 2 + 1; + while (iterations--) { + if (num >= res + bit) { + num -= res + bit; + res = (res >> 1U) + bit; + } else { + res >>= 1U; + } + bit >>= 2U; + } + // Do rounding - if we have the bits. + if (num > res && res != 0xFFFF) { + ++res; + } + return res; +} + +static uint32_t Sqrt64(uint64_t num) { + // Take a shortcut and just use 32 bit operations if the upper word is all + // clear. This will cause a slight off by one issue for numbers close to 2^32, + // but it probably isn't going to matter (and gives us a big performance win). + if ((num >> 32) == 0) { + return Sqrt32((uint32_t) num); + } + uint64_t res = 0; + int max_bit_number = 64 - MostSignificantBit64(num); + max_bit_number |= 1; + uint64_t bit = 1ULL << (63 - max_bit_number); + int iterations = (63 - max_bit_number) / 2 + 1; + while (iterations--) { + if (num >= res + bit) { + num -= res + bit; + res = (res >> 1U) + bit; + } else { + res >>= 1U; + } + bit >>= 2U; + } + // Do rounding - if we have the bits. + if (num > res && res != 0xFFFFFFFFLL) { + ++res; + } + return res; +} + +uint32_t* FilterbankSqrt(struct FilterbankState* state, int scale_down_shift) { + const int num_channels = state->num_channels; + const int64_t* work = state->work + 1; + // Reuse the work buffer since we're fine clobbering it at this point to hold + // the output. 
+ uint32_t* output = (uint32_t*) state->work; + int i; + for (i = 0; i < num_channels; ++i) { + *output++ = Sqrt64(*work++) >> scale_down_shift; + } + return (uint32_t*) state->work; +} + +void FilterbankReset(struct FilterbankState* state) { + memset(state->work, 0, (state->num_channels + 1) * sizeof(*state->work)); +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h new file mode 100644 index 0000000000000000000000000000000000000000..0dd9c3fa6516808527ab0d1dd5a4b7bee1449f28 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h @@ -0,0 +1,63 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_H_ + +#include +#include + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h" + +#define kFilterbankBits 12 + +#ifdef __cplusplus +extern "C" { +#endif + +struct FilterbankState { + int num_channels; + int start_index; + int end_index; + int16_t* channel_frequency_starts; + int16_t* channel_weight_starts; + int16_t* channel_widths; + int16_t* weights; + int16_t* unweights; + uint64_t* work; +}; + +// Converts the relevant complex values of an FFT output into energy (the +// square magnitude). +void FilterbankConvertFftComplexToEnergy(struct FilterbankState* state, + struct complex_int16_t* fft_output, + int32_t* energy); + +// Computes the mel-scale filterbank on the given energy array. Output is cached +// internally - to fetch it, you need to call FilterbankSqrt. +void FilterbankAccumulateChannels(struct FilterbankState* state, + const int32_t* energy); + +// Applies an integer square root to the 64 bit intermediate values of the +// filterbank, and returns a pointer to them. Memory will be invalidated the +// next time FilterbankAccumulateChannels is called. 
+uint32_t* FilterbankSqrt(struct FilterbankState* state, int scale_down_shift); + +void FilterbankReset(struct FilterbankState* state); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.c new file mode 100644 index 0000000000000000000000000000000000000000..672ddd530f847c6bc12643a7d2d77156a75db90f --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.c @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.h" + +static void PrintArray(FILE* fp, const char* name, const int16_t* values, + size_t size) { + fprintf(fp, "static int16_t filterbank_%s[] = {", name); + for (int i = 0; i < size; ++i) { + fprintf(fp, "%d", values[i]); + if (i < size - 1) { + fprintf(fp, ", "); + } + } + fprintf(fp, "};\n"); +} + +void FilterbankWriteMemmapPreamble(FILE* fp, + const struct FilterbankState* state) { + const int num_channels_plus_1 = state->num_channels + 1; + + PrintArray(fp, "channel_frequency_starts", state->channel_frequency_starts, + num_channels_plus_1); + PrintArray(fp, "channel_weight_starts", state->channel_weight_starts, + num_channels_plus_1); + PrintArray(fp, "channel_widths", state->channel_widths, num_channels_plus_1); + int num_weights = 0; + int i; + for (i = 0; i < num_channels_plus_1; ++i) { + num_weights += state->channel_widths[i]; + } + PrintArray(fp, "weights", state->weights, num_weights); + PrintArray(fp, "unweights", state->unweights, num_weights); + + fprintf(fp, "static uint64_t filterbank_work[%d];\n", num_channels_plus_1); + fprintf(fp, "\n"); +} + +void FilterbankWriteMemmap(FILE* fp, const struct FilterbankState* state, + const char* variable) { + fprintf(fp, "%s->num_channels = %d;\n", variable, state->num_channels); + fprintf(fp, "%s->start_index = %d;\n", variable, state->start_index); + fprintf(fp, "%s->end_index = %d;\n", variable, state->end_index); + + fprintf( + fp, + "%s->channel_frequency_starts = filterbank_channel_frequency_starts;\n", + variable); + fprintf(fp, "%s->channel_weight_starts = filterbank_channel_weight_starts;\n", + variable); + fprintf(fp, "%s->channel_widths = filterbank_channel_widths;\n", variable); + fprintf(fp, "%s->weights = filterbank_weights;\n", variable); + fprintf(fp, "%s->unweights = filterbank_unweights;\n", variable); + fprintf(fp, "%s->work = 
filterbank_work;\n", variable); +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.h new file mode 100644 index 0000000000000000000000000000000000000000..1ddc314df2234de74b3712780a7dff4504017747 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_IO_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_IO_H_ + +#include + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void FilterbankWriteMemmapPreamble(FILE* fp, + const struct FilterbankState* state); +void FilterbankWriteMemmap(FILE* fp, const struct FilterbankState* state, + const char* variable); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_IO_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_test.cc b/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..88d8de4b8f0ebd632c22f1665c1f5281025a338b --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_test.cc @@ -0,0 +1,194 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.h" + +#include + +#include +#include + +namespace { + +const int kSampleRate = 1000; +const int kSpectrumSize = 17; +const int kStartIndex = 1; +const int kEndIndex = 15; +const int32_t kEnergy[] = {-1, 181, 400, 181, 625, 28322, + 786769, 18000000, 40972801, 18000000, 784996, 28085, + 625, 181, 361, -1, -1}; +const uint64_t kWork[] = {1835887, 61162970173, 258694800000}; +const int kScaleShift = 0; + +// Test filterbank generation using scaled-down defaults. +class FilterbankTest : public ::testing::Test { + protected: + FilterbankTest() { + config_.num_channels = 2; + config_.lower_band_limit = 8.0; + config_.upper_band_limit = 450.0; + } + + struct FilterbankConfig config_; +}; + +TEST_F(FilterbankTest, CheckStartIndex) { + struct FilterbankState state; + ASSERT_TRUE( + FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + + EXPECT_EQ(state.start_index, kStartIndex); + + FilterbankFreeStateContents(&state); +} + +TEST_F(FilterbankTest, CheckEndIndex) { + struct FilterbankState state; + ASSERT_TRUE( + FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + + EXPECT_EQ(state.end_index, kEndIndex); + + FilterbankFreeStateContents(&state); +} + +TEST_F(FilterbankTest, CheckChannelFrequencyStarts) { + struct FilterbankState state; + ASSERT_TRUE( + FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + + const int16_t expected[] = {0, 4, 8}; + ASSERT_EQ(state.num_channels + 1, sizeof(expected) / sizeof(expected[0])); + for (int i = 0; i <= state.num_channels; ++i) { + EXPECT_EQ(state.channel_frequency_starts[i], expected[i]); + } + + FilterbankFreeStateContents(&state); +} + +TEST_F(FilterbankTest, CheckChannelWeightStarts) { + struct FilterbankState state; + ASSERT_TRUE( 
+ FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + + const int16_t expected[] = {0, 8, 16}; + ASSERT_EQ(state.num_channels + 1, sizeof(expected) / sizeof(expected[0])); + for (int i = 0; i <= state.num_channels; ++i) { + EXPECT_EQ(state.channel_weight_starts[i], expected[i]); + } + + FilterbankFreeStateContents(&state); +} + +TEST_F(FilterbankTest, CheckChannelWidths) { + struct FilterbankState state; + ASSERT_TRUE( + FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + + const int16_t expected[] = {8, 8, 8}; + ASSERT_EQ(state.num_channels + 1, sizeof(expected) / sizeof(expected[0])); + for (int i = 0; i <= state.num_channels; ++i) { + EXPECT_EQ(state.channel_widths[i], expected[i]); + } + + FilterbankFreeStateContents(&state); +} + +TEST_F(FilterbankTest, CheckWeights) { + struct FilterbankState state; + ASSERT_TRUE( + FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + + const int16_t expected[] = {0, 3277, 2217, 1200, 222, 0, 0, 0, + 0, 3376, 2468, 1591, 744, 0, 0, 0, + 0, 4020, 3226, 2456, 1708, 983, 277, 0}; + ASSERT_EQ(state.channel_weight_starts[state.num_channels] + + state.channel_widths[state.num_channels], + sizeof(expected) / sizeof(expected[0])); + for (int i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) { + EXPECT_EQ(state.weights[i], expected[i]); + } + + FilterbankFreeStateContents(&state); +} + +TEST_F(FilterbankTest, CheckUnweights) { + struct FilterbankState state; + ASSERT_TRUE( + FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + + const int16_t expected[] = {0, 819, 1879, 2896, 3874, 0, 0, 0, + 0, 720, 1628, 2505, 3352, 0, 0, 0, + 0, 76, 870, 1640, 2388, 3113, 3819, 0}; + ASSERT_EQ(state.channel_weight_starts[state.num_channels] + + state.channel_widths[state.num_channels], + sizeof(expected) / sizeof(expected[0])); + for (int i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) { + EXPECT_EQ(state.unweights[i], expected[i]); + } + + 
FilterbankFreeStateContents(&state); +} + +TEST_F(FilterbankTest, CheckConvertFftComplexToEnergy) { + struct FilterbankState state; + state.start_index = kStartIndex; + state.end_index = kEndIndex; + + struct complex_int16_t fake_fft[] = { + {0, 0}, {-10, 9}, {-20, 0}, {-9, -10}, {0, 25}, {-119, 119}, + {-887, 0}, {3000, 3000}, {0, -6401}, {-3000, 3000}, {886, 0}, {118, 119}, + {0, 25}, {9, -10}, {19, 0}, {9, 9}, {0, 0}}; + int32_t* energy = reinterpret_cast(fake_fft); + FilterbankConvertFftComplexToEnergy(&state, fake_fft, energy); + + for (int i = state.start_index; i < state.end_index; ++i) { + EXPECT_EQ(energy[i], kEnergy[i]); + } +} + +TEST_F(FilterbankTest, CheckAccumulateChannels) { + struct FilterbankState state; + ASSERT_TRUE( + FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + + FilterbankAccumulateChannels(&state, kEnergy); + + ASSERT_EQ(state.num_channels + 1, sizeof(kWork) / sizeof(kWork[0])); + for (int i = 0; i <= state.num_channels; ++i) { + EXPECT_EQ(state.work[i], kWork[i]); + } + + FilterbankFreeStateContents(&state); +} + +TEST_F(FilterbankTest, CheckSqrt) { + struct FilterbankState state; + ASSERT_TRUE( + FilterbankPopulateState(&config_, &state, kSampleRate, kSpectrumSize)); + std::memcpy(state.work, kWork, sizeof(kWork)); + + uint32_t* scaled_filterbank = FilterbankSqrt(&state, kScaleShift); + + const uint32_t expected[] = {247311, 508620}; + ASSERT_EQ(state.num_channels, sizeof(expected) / sizeof(expected[0])); + for (int i = 0; i < state.num_channels; ++i) { + EXPECT_EQ(scaled_filterbank[i], expected[i]); + } + + FilterbankFreeStateContents(&state); +} + +} // namespace diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.c new file mode 100644 index 0000000000000000000000000000000000000000..53b5e45073455b01be869100f8d2cd26834ab9bb --- /dev/null +++ 
b/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.c @@ -0,0 +1,225 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.h" + +#include +#include +#include + +#define kFilterbankIndexAlignment 4 +#define kFilterbankChannelBlockSize 4 + +void FilterbankFillConfigWithDefaults(struct FilterbankConfig* config) { + config->num_channels = 32; + config->lower_band_limit = 125.0f; + config->upper_band_limit = 7500.0f; + config->output_scale_shift = 7; +} + +static float FreqToMel(float freq) { + return 1127.0 * log(1.0 + (freq / 700.0)); +} + +static void CalculateCenterFrequencies(const int num_channels, + const float lower_frequency_limit, + const float upper_frequency_limit, + float* center_frequencies) { + assert(lower_frequency_limit >= 0.0f); + assert(upper_frequency_limit > lower_frequency_limit); + + const float mel_low = FreqToMel(lower_frequency_limit); + const float mel_hi = FreqToMel(upper_frequency_limit); + const float mel_span = mel_hi - mel_low; + const float mel_spacing = mel_span / ((float) num_channels); + int i; + for (i = 0; i < num_channels; ++i) { + center_frequencies[i] = mel_low + (mel_spacing * (i + 1)); + } +} + +static void QuantizeFilterbankWeights(const float float_weight, + int16_t* weight, int16_t* unweight) { + *weight = 
floor(float_weight * (1 << kFilterbankBits) + 0.5); + *unweight = floor((1.0 - float_weight) * (1 << kFilterbankBits) + 0.5); +} + +int FilterbankPopulateState(const struct FilterbankConfig* config, + struct FilterbankState* state, + int sample_rate, int spectrum_size) { + state->num_channels = config->num_channels; + const int num_channels_plus_1 = config->num_channels + 1; + + // How should we align things to index counts given the byte alignment? + const int index_alignment = + (kFilterbankIndexAlignment < sizeof(int16_t) + ? 1 + : kFilterbankIndexAlignment / sizeof(int16_t)); + + state->channel_frequency_starts = + malloc(num_channels_plus_1 * sizeof(*state->channel_frequency_starts)); + state->channel_weight_starts = + malloc(num_channels_plus_1 * sizeof(*state->channel_weight_starts)); + state->channel_widths = + malloc(num_channels_plus_1 * sizeof(*state->channel_widths)); + state->work = malloc(num_channels_plus_1 * sizeof(*state->work)); + + float* center_mel_freqs = + malloc(num_channels_plus_1 * sizeof(*center_mel_freqs)); + int16_t* actual_channel_starts = + malloc(num_channels_plus_1 * sizeof(*actual_channel_starts)); + int16_t* actual_channel_widths = + malloc(num_channels_plus_1 * sizeof(*actual_channel_widths)); + + if (state->channel_frequency_starts == NULL || + state->channel_weight_starts == NULL || + state->channel_widths == NULL || + center_mel_freqs == NULL || + actual_channel_starts == NULL || + actual_channel_widths == NULL) { + free(center_mel_freqs); + free(actual_channel_starts); + free(actual_channel_widths); + fprintf(stderr, "Failed to allocate channel buffers\n"); + return 0; + } + + CalculateCenterFrequencies(num_channels_plus_1, config->lower_band_limit, + config->upper_band_limit, center_mel_freqs); + + // Always exclude DC. 
+ const float hz_per_sbin = 0.5 * sample_rate / ((float) spectrum_size - 1); + state->start_index = 1.5 + config->lower_band_limit / hz_per_sbin; + state->end_index = 0; // Initialized to zero here, but actually set below. + + // For each channel, we need to figure out what frequencies belong to it, and + // how much padding we need to add so that we can efficiently multiply the + // weights and unweights for accumulation. To simplify the multiplication + // logic, all channels will have some multiplication to do (even if there are + // no frequencies that accumulate to that channel) - they will be directed to + // a set of zero weights. + int chan_freq_index_start = state->start_index; + int weight_index_start = 0; + int needs_zeros = 0; + + int chan; + for (chan = 0; chan < num_channels_plus_1; ++chan) { + // Keep jumping frequencies until we overshoot the bound on this channel. + int freq_index = chan_freq_index_start; + while (FreqToMel((freq_index) * hz_per_sbin) <= center_mel_freqs[chan]) { + ++freq_index; + } + + const int width = freq_index - chan_freq_index_start; + actual_channel_starts[chan] = chan_freq_index_start; + actual_channel_widths[chan] = width; + + if (width == 0) { + // This channel doesn't actually get anything from the frequencies, it's + // always zero. We need then to insert some 'zero' weights into the + // output, and just redirect this channel to do a single multiplication at + // this point. For simplicity, the zeros are placed at the beginning of + // the weights arrays, so we have to go and update all the other + // weight_starts to reflect this shift (but only once). 
+ state->channel_frequency_starts[chan] = 0; + state->channel_weight_starts[chan] = 0; + state->channel_widths[chan] = kFilterbankChannelBlockSize; + if (!needs_zeros) { + needs_zeros = 1; + int j; + for (j = 0; j < chan; ++j) { + state->channel_weight_starts[j] += kFilterbankChannelBlockSize; + } + weight_index_start += kFilterbankChannelBlockSize; + } + } else { + // How far back do we need to go to ensure that we have the proper + // alignment? + const int aligned_start = + (chan_freq_index_start / index_alignment) * index_alignment; + const int aligned_width = + (chan_freq_index_start - aligned_start + width); + const int padded_width = + (((aligned_width - 1) / kFilterbankChannelBlockSize) + 1) * + kFilterbankChannelBlockSize; + + state->channel_frequency_starts[chan] = aligned_start; + state->channel_weight_starts[chan] = weight_index_start; + state->channel_widths[chan] = padded_width; + weight_index_start += padded_width; + } + chan_freq_index_start = freq_index; + } + + // Allocate the two arrays to store the weights - weight_index_start contains + // the index of what would be the next set of weights that we would need to + // add, so that's how many weights we need to allocate. + state->weights = calloc(weight_index_start, sizeof(*state->weights)); + state->unweights = calloc(weight_index_start, sizeof(*state->unweights)); + + // If the alloc failed, we also need to nuke the arrays. + if (state->weights == NULL || state->unweights == NULL) { + free(center_mel_freqs); + free(actual_channel_starts); + free(actual_channel_widths); + fprintf(stderr, "Failed to allocate weights or unweights\n"); + return 0; + } + + // Next pass, compute all the weights. Since everything has been memset to + // zero, we only need to fill in the weights that correspond to some frequency + // for a channel. 
+ const float mel_low = FreqToMel(config->lower_band_limit); + for (chan = 0; chan < num_channels_plus_1; ++chan) { + int frequency = actual_channel_starts[chan]; + const int num_frequencies = actual_channel_widths[chan]; + const int frequency_offset = + frequency - state->channel_frequency_starts[chan]; + const int weight_start = state->channel_weight_starts[chan]; + const float denom_val = (chan == 0) ? mel_low : center_mel_freqs[chan - 1]; + + int j; + for (j = 0; j < num_frequencies; ++j, ++frequency) { + const float weight = + (center_mel_freqs[chan] - FreqToMel(frequency * hz_per_sbin)) / + (center_mel_freqs[chan] - denom_val); + + // Make the float into an integer for the weights (and unweights). + const int weight_index = weight_start + frequency_offset + j; + QuantizeFilterbankWeights(weight, state->weights + weight_index, + state->unweights + weight_index); + } + if (frequency > state->end_index) { + state->end_index = frequency; + } + } + + free(center_mel_freqs); + free(actual_channel_starts); + free(actual_channel_widths); + if (state->end_index >= spectrum_size) { + fprintf(stderr, "Filterbank end_index is above spectrum size.\n"); + return 0; + } + return 1; +} + +void FilterbankFreeStateContents(struct FilterbankState* state) { + free(state->channel_frequency_starts); + free(state->channel_weight_starts); + free(state->channel_widths); + free(state->weights); + free(state->unweights); + free(state->work); +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.h new file mode 100644 index 0000000000000000000000000000000000000000..9ec9bc930286e3fd6a40a1ac080f4b992503cc2e --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_UTIL_H_ + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct FilterbankConfig { + // number of frequency channel buckets for filterbank + int num_channels; + // maximum frequency to include + float upper_band_limit; + // minimum frequency to include + float lower_band_limit; + // unused + int output_scale_shift; +}; + +// Fills the frontendConfig with "sane" defaults. +void FilterbankFillConfigWithDefaults(struct FilterbankConfig* config); + +// Allocates any buffers. +int FilterbankPopulateState(const struct FilterbankConfig* config, + struct FilterbankState* state, int sample_rate, + int spectrum_size); + +// Frees any allocated buffers. 
+void FilterbankFreeStateContents(struct FilterbankState* state); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.c new file mode 100644 index 0000000000000000000000000000000000000000..de7a60b56fd85a702469cec52ecd1c29b209eae0 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.c @@ -0,0 +1,72 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/bits.h" + +struct FrontendOutput FrontendProcessSamples(struct FrontendState* state, + const int16_t* samples, + size_t num_samples, + size_t* num_samples_read) { + struct FrontendOutput output; + output.values = NULL; + output.size = 0; + + // Try to apply the window - if it fails, return and wait for more data. + if (!WindowProcessSamples(&state->window, samples, num_samples, + num_samples_read)) { + return output; + } + + // Apply the FFT to the window's output (and scale it so that the fixed point + // FFT can have as much resolution as possible). 
+ int input_shift = + 15 - MostSignificantBit32(state->window.max_abs_output_value); + FftCompute(&state->fft, state->window.output, input_shift); + + // We can re-ruse the fft's output buffer to hold the energy. + int32_t* energy = (int32_t*) state->fft.output; + + FilterbankConvertFftComplexToEnergy(&state->filterbank, state->fft.output, + energy); + + FilterbankAccumulateChannels(&state->filterbank, energy); + uint32_t* scaled_filterbank = FilterbankSqrt(&state->filterbank, input_shift); + + // Apply noise reduction. + NoiseReductionApply(&state->noise_reduction, scaled_filterbank); + + if (state->pcan_gain_control.enable_pcan) { + PcanGainControlApply(&state->pcan_gain_control, scaled_filterbank); + } + + // Apply the log and scale. + int correction_bits = + MostSignificantBit32(state->fft.fft_size) - 1 - (kFilterbankBits / 2); + uint16_t* logged_filterbank = + LogScaleApply(&state->log_scale, scaled_filterbank, + state->filterbank.num_channels, correction_bits); + + output.size = state->filterbank.num_channels; + output.values = logged_filterbank; + return output; +} + +void FrontendReset(struct FrontendState* state) { + WindowReset(&state->window); + FftReset(&state->fft); + FilterbankReset(&state->filterbank); + NoiseReductionReset(&state->noise_reduction); +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h new file mode 100644 index 0000000000000000000000000000000000000000..71ae81024cb3aee6248de17712d7051665aed97c --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h @@ -0,0 +1,64 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_H_ + +#include +#include + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct FrontendState { + struct WindowState window; + struct FftState fft; + struct FilterbankState filterbank; + struct NoiseReductionState noise_reduction; + struct PcanGainControlState pcan_gain_control; + struct LogScaleState log_scale; +}; + +struct FrontendOutput { + const uint16_t* values; + size_t size; +}; + +// Main entry point to processing frontend samples. Updates num_samples_read to +// contain the number of samples that have been consumed from the input array. +// Returns a struct containing the generated output. If not enough samples were +// added to generate a feature vector, the returned size will be 0 and the +// values pointer will be NULL. 
Note that the output pointer will be invalidated +// as soon as FrontendProcessSamples is called again, so copy the contents +// elsewhere if you need to use them later. +struct FrontendOutput FrontendProcessSamples(struct FrontendState* state, + const int16_t* samples, + size_t num_samples, + size_t* num_samples_read); + +void FrontendReset(struct FrontendState* state); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.c new file mode 100644 index 0000000000000000000000000000000000000000..40bcf247497dffd0f8008d704a609602ecb0d0fb --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.c @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.h" + +#include + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.h" + +int WriteFrontendStateMemmap(const char* header, const char* source, + const struct FrontendState* state) { + // Write a header that just has our init function. + FILE* fp = fopen(header, "w"); + if (!fp) { + fprintf(stderr, "Failed to open header '%s' for write\n", header); + return 0; + } + fprintf(fp, "#ifndef FRONTEND_STATE_MEMMAP_H_\n"); + fprintf(fp, "#define FRONTEND_STATE_MEMMAP_H_\n"); + fprintf(fp, "\n"); + fprintf(fp, "#include \"frontend.h\"\n"); + fprintf(fp, "\n"); + fprintf(fp, "struct FrontendState* GetFrontendStateMemmap();\n"); + fprintf(fp, "\n"); + fprintf(fp, "#endif // FRONTEND_STATE_MEMMAP_H_\n"); + fclose(fp); + + // Write out the source file that actually has everything in it. 
+ fp = fopen(source, "w"); + if (!fp) { + fprintf(stderr, "Failed to open source '%s' for write\n", source); + return 0; + } + fprintf(fp, "#include \"%s\"\n", header); + fprintf(fp, "\n"); + WindowWriteMemmapPreamble(fp, &state->window); + FftWriteMemmapPreamble(fp, &state->fft); + FilterbankWriteMemmapPreamble(fp, &state->filterbank); + NoiseReductionWriteMemmapPreamble(fp, &state->noise_reduction); + fprintf(fp, "static struct FrontendState state;\n"); + fprintf(fp, "struct FrontendState* GetFrontendStateMemmap() {\n"); + WindowWriteMemmap(fp, &state->window, " (&state.window)"); + FftWriteMemmap(fp, &state->fft, " (&state.fft)"); + FilterbankWriteMemmap(fp, &state->filterbank, " (&state.filterbank)"); + NoiseReductionWriteMemmap(fp, &state->noise_reduction, + " (&state.noise_reduction)"); + LogScaleWriteMemmap(fp, &state->log_scale, " (&state.log_scale)"); + fprintf(fp, " FftInit(&state.fft);\n"); + fprintf(fp, " FrontendReset(&state);\n"); + fprintf(fp, " return &state;\n"); + fprintf(fp, "}\n"); + fclose(fp); + return 1; +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.h new file mode 100644 index 0000000000000000000000000000000000000000..4f45577caeab7aaff9210355d9d4d811dedfdf16 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_IO_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_IO_H_ + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int WriteFrontendStateMemmap(const char* header, const char* source, + const struct FrontendState* state); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_IO_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_main.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_main.c new file mode 100644 index 0000000000000000000000000000000000000000..46caebeec9059c9d13f19e4ac30d602a6b47eaba --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_main.c @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h" + +int main(int argc, char** argv) { + struct FrontendConfig frontend_config; + FrontendFillConfigWithDefaults(&frontend_config); + + char* filename = argv[1]; + int sample_rate = 16000; + + struct FrontendState frontend_state; + if (!FrontendPopulateState(&frontend_config, &frontend_state, sample_rate)) { + fprintf(stderr, "Failed to populate frontend state\n"); + FrontendFreeStateContents(&frontend_state); + return 1; + } + + + FILE* fp = fopen(filename, "r"); + if (fp == NULL) { + fprintf(stderr, "Failed to open %s for read\n", filename); + return 1; + } + fseek(fp, 0L, SEEK_END); + size_t audio_file_size = ftell(fp) / sizeof(int16_t); + fseek(fp, 0L, SEEK_SET); + int16_t* audio_data = malloc(audio_file_size * sizeof(int16_t)); + int16_t* original_audio_data = audio_data; + if (audio_file_size != + fread(audio_data, sizeof(int16_t), audio_file_size, fp)) { + fprintf(stderr, "Failed to read in all audio data\n"); + return 1; + } + + while (audio_file_size > 0) { + size_t num_samples_read; + struct FrontendOutput output = FrontendProcessSamples( + &frontend_state, audio_data, audio_file_size, &num_samples_read); + audio_data += num_samples_read; + audio_file_size -= num_samples_read; + + if (output.values != NULL) { + int i; + for (i = 0; i < output.size; ++i) { + printf("%d ", output.values[i]); + } + printf("\n"); + } + } + + FrontendFreeStateContents(&frontend_state); + free(original_audio_data); + return 0; +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_memmap_generator.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_memmap_generator.c new file mode 100644 index 0000000000000000000000000000000000000000..a4c59b0cccabb70579a0267b6c95ec0b39908276 --- /dev/null 
+++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_memmap_generator.c @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.h" + +int main(int argc, char** argv) { + if (argc != 3) { + fprintf(stderr, + "%s requires exactly two parameters - the names of the header and " + "source files to save\n"); + return 1; + } + struct FrontendConfig frontend_config; + FrontendFillConfigWithDefaults(&frontend_config); + + int sample_rate = 16000; + struct FrontendState frontend_state; + if (!FrontendPopulateState(&frontend_config, &frontend_state, sample_rate)) { + fprintf(stderr, "Failed to populate frontend state\n"); + FrontendFreeStateContents(&frontend_state); + return 1; + } + + if (!WriteFrontendStateMemmap(argv[1], argv[2], &frontend_state)) { + fprintf(stderr, "Failed to write memmap\n"); + FrontendFreeStateContents(&frontend_state); + return 1; + } + + FrontendFreeStateContents(&frontend_state); + return 0; +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_memmap_main.c 
b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_memmap_main.c new file mode 100644 index 0000000000000000000000000000000000000000..a4264922b94b5af21f6a2444d55bf10607af4a23 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_memmap_main.c @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" +#include "memmap.h" + +int main(int argc, char** argv) { + struct FrontendState* frontend_state = GetFrontendStateMemmap(); + + char* filename = argv[1]; + FILE* fp = fopen(filename, "r"); + if (fp == NULL) { + fprintf(stderr, "Failed to open %s for read\n", filename); + return 1; + } + fseek(fp, 0L, SEEK_END); + size_t audio_file_size = ftell(fp) / sizeof(int16_t); + fseek(fp, 0L, SEEK_SET); + int16_t* audio_data = malloc(audio_file_size * sizeof(int16_t)); + int16_t* original_audio_data = audio_data; + if (audio_file_size != + fread(audio_data, sizeof(int16_t), audio_file_size, fp)) { + fprintf(stderr, "Failed to read in all audio data\n"); + return 1; + } + + while (audio_file_size > 0) { + size_t num_samples_read; + struct FrontendOutput output = FrontendProcessSamples( + frontend_state, audio_data, audio_file_size, &num_samples_read); + audio_data += num_samples_read; + audio_file_size -= 
num_samples_read; + + if (output.values != NULL) { + int i; + for (i = 0; i < output.size; ++i) { + printf("%d ", output.values[i]); + } + printf("\n"); + } + } + + free(original_audio_data); + return 0; +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_test.cc b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f06e2565c285e62c6921121ee94cc911f0f52fc7 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_test.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h" + +#include +#include + +namespace { + +const int kSampleRate = 1000; +const int kWindowSamples = 25; +const int kStepSamples = 10; +const int16_t kFakeAudioData[] = { + 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768, + 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768, + 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768}; + +// Test end-to-end frontend behaviors. 
+class FrontendTest : public ::testing::Test { + protected: + FrontendTest() { + config_.window.size_ms = 25; + config_.window.step_size_ms = 10; + config_.noise_reduction.smoothing_bits = 10; + config_.filterbank.num_channels = 2; + config_.filterbank.lower_band_limit = 8.0; + config_.filterbank.upper_band_limit = 450.0; + config_.noise_reduction.smoothing_bits = 10; + config_.noise_reduction.even_smoothing = 0.025; + config_.noise_reduction.odd_smoothing = 0.06; + config_.noise_reduction.min_signal_remaining = 0.05; + config_.pcan_gain_control.enable_pcan = true; + config_.pcan_gain_control.strength = 0.95; + config_.pcan_gain_control.offset = 80.0; + config_.pcan_gain_control.gain_bits = 21; + config_.log_scale.enable_log = true; + config_.log_scale.scale_shift = 6; + } + + struct FrontendConfig config_; +}; + +TEST_F(FrontendTest, CheckOutputValues) { + struct FrontendState state; + ASSERT_TRUE(FrontendPopulateState(&config_, &state, kSampleRate)); + size_t num_samples_read; + + struct FrontendOutput output = FrontendProcessSamples( + &state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read); + + const uint16_t expected[] = {479, 425}; + ASSERT_EQ(output.size, sizeof(expected) / sizeof(expected[0])); + for (int i = 0; i < output.size; ++i) { + EXPECT_EQ(output.values[i], expected[i]); + } + + FrontendFreeStateContents(&state); +} + +TEST_F(FrontendTest, CheckConsecutiveWindow) { + struct FrontendState state; + ASSERT_TRUE(FrontendPopulateState(&config_, &state, kSampleRate)); + size_t num_samples_read; + + FrontendProcessSamples(&state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), + &num_samples_read); + struct FrontendOutput output = FrontendProcessSamples( + &state, kFakeAudioData + kWindowSamples, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - kWindowSamples, + &num_samples_read); + + const int16_t expected[] = {436, 378}; + ASSERT_EQ(output.size, sizeof(expected) / sizeof(expected[0])); 
+ for (int i = 0; i < output.size; ++i) { + EXPECT_EQ(output.values[i], expected[i]); + } + + FrontendFreeStateContents(&state); +} + +TEST_F(FrontendTest, CheckNotEnoughSamples) { + struct FrontendState state; + ASSERT_TRUE(FrontendPopulateState(&config_, &state, kSampleRate)); + size_t num_samples_read; + + FrontendProcessSamples(&state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), + &num_samples_read); + FrontendProcessSamples( + &state, kFakeAudioData + kWindowSamples, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - kWindowSamples, + &num_samples_read); + struct FrontendOutput output = FrontendProcessSamples( + &state, kFakeAudioData + kWindowSamples + kStepSamples, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - kWindowSamples - + kStepSamples, + &num_samples_read); + + EXPECT_EQ(output.size, 0); + EXPECT_EQ(output.values, nullptr); + + FrontendFreeStateContents(&state); +} + +} // namespace diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.c new file mode 100644 index 0000000000000000000000000000000000000000..ae2d9ae6c4c8bb23a2d26e8f8c48d6b4788217e5 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.c @@ -0,0 +1,87 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h" + +#include +#include + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/bits.h" + +void FrontendFillConfigWithDefaults(struct FrontendConfig* config) { + WindowFillConfigWithDefaults(&config->window); + FilterbankFillConfigWithDefaults(&config->filterbank); + NoiseReductionFillConfigWithDefaults(&config->noise_reduction); + PcanGainControlFillConfigWithDefaults(&config->pcan_gain_control); + LogScaleFillConfigWithDefaults(&config->log_scale); +} + +int FrontendPopulateState(const struct FrontendConfig* config, + struct FrontendState* state, int sample_rate) { + memset(state, 0, sizeof(*state)); + + if (!WindowPopulateState(&config->window, &state->window, sample_rate)) { + fprintf(stderr, "Failed to populate window state\n"); + return 0; + } + + if (!FftPopulateState(&state->fft, state->window.size)) { + fprintf(stderr, "Failed to populate fft state\n"); + return 0; + } + FftInit(&state->fft); + + if (!FilterbankPopulateState(&config->filterbank, &state->filterbank, + sample_rate, state->fft.fft_size / 2 + 1)) { + fprintf(stderr, "Failed to populate filterbank state\n"); + return 0; + } + + if (!NoiseReductionPopulateState(&config->noise_reduction, + &state->noise_reduction, + state->filterbank.num_channels)) { + fprintf(stderr, "Failed to populate noise reduction state\n"); + return 0; + } + + int input_correction_bits = + MostSignificantBit32(state->fft.fft_size) - 1 - (kFilterbankBits / 2); + if (!PcanGainControlPopulateState(&config->pcan_gain_control, + &state->pcan_gain_control, + state->noise_reduction.estimate, + state->filterbank.num_channels, + state->noise_reduction.smoothing_bits, + input_correction_bits)) { + fprintf(stderr, "Failed to populate pcan gain control state\n"); + return 0; + } + + if (!LogScalePopulateState(&config->log_scale, &state->log_scale)) { + 
fprintf(stderr, "Failed to populate log scale state\n"); + return 0; + } + + FrontendReset(state); + + // All good, return a true value. + return 1; +} + +void FrontendFreeStateContents(struct FrontendState* state) { + WindowFreeStateContents(&state->window); + FftFreeStateContents(&state->fft); + FilterbankFreeStateContents(&state->filterbank); + NoiseReductionFreeStateContents(&state->noise_reduction); + PcanGainControlFreeStateContents(&state->pcan_gain_control); +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h new file mode 100644 index 0000000000000000000000000000000000000000..a958b610eae689192aa96eacefe654292350fcd7 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_UTIL_H_ + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct FrontendConfig { + struct WindowConfig window; + struct FilterbankConfig filterbank; + struct NoiseReductionConfig noise_reduction; + struct PcanGainControlConfig pcan_gain_control; + struct LogScaleConfig log_scale; +}; + +// Fills the frontendConfig with "sane" defaults. +void FrontendFillConfigWithDefaults(struct FrontendConfig* config); + +// Allocates any buffers. +int FrontendPopulateState(const struct FrontendConfig* config, + struct FrontendState* state, int sample_rate); + +// Frees any allocated buffers. 
+void FrontendFreeStateContents(struct FrontendState* state); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.c new file mode 100644 index 0000000000000000000000000000000000000000..f8d32102336d19650703e740162e32eef6a6d287 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.c @@ -0,0 +1,30 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.h" +const uint16_t kLogLut[] +#ifndef _MSC_VER + __attribute__((aligned(4))) +#endif // _MSC_VER + = {0, 224, 442, 654, 861, 1063, 1259, 1450, 1636, 1817, 1992, 2163, + 2329, 2490, 2646, 2797, 2944, 3087, 3224, 3358, 3487, 3611, 3732, 3848, + 3960, 4068, 4172, 4272, 4368, 4460, 4549, 4633, 4714, 4791, 4864, 4934, + 5001, 5063, 5123, 5178, 5231, 5280, 5326, 5368, 5408, 5444, 5477, 5507, + 5533, 5557, 5578, 5595, 5610, 5622, 5631, 5637, 5640, 5641, 5638, 5633, + 5626, 5615, 5602, 5586, 5568, 5547, 5524, 5498, 5470, 5439, 5406, 5370, + 5332, 5291, 5249, 5203, 5156, 5106, 5054, 5000, 4944, 4885, 4825, 4762, + 4697, 4630, 4561, 4490, 4416, 4341, 4264, 4184, 4103, 4020, 3935, 3848, + 3759, 3668, 3575, 3481, 3384, 3286, 3186, 3084, 2981, 2875, 2768, 2659, + 2549, 2437, 2323, 2207, 2090, 1971, 1851, 1729, 1605, 1480, 1353, 1224, + 1094, 963, 830, 695, 559, 421, 282, 142, 0, 0}; diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.h new file mode 100644 index 0000000000000000000000000000000000000000..53dd1fa4052d2f7f26e4428772670af658f61878 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_LUT_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_LUT_H_ + +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif + +// Number of segments in the log lookup table. The table will be kLogSegments+1 +// in length (with some padding). +#define kLogSegments 128 +#define kLogSegmentsLog2 7 + +// Scale used by lookup table. +#define kLogScale 65536 +#define kLogScaleLog2 16 +#define kLogCoeff 45426 + +extern const uint16_t kLogLut[]; + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_LUT_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.c new file mode 100644 index 0000000000000000000000000000000000000000..4b1246187155e82cd9495a9a4020483526fcbd2b --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.c @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h" + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/bits.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.h" + +#define kuint16max 0x0000FFFF + +// The following functions implement integer logarithms of various sizes. The +// approximation is calculated according to method described in +// www.inti.gob.ar/electronicaeinformatica/instrumentacion/utic/ +// publicaciones/SPL2007/Log10-spl07.pdf +// It first calculates log2 of the input and then converts it to natural +// logarithm. + +static uint32_t Log2FractionPart(const uint32_t x, const uint32_t log2x) { + // Part 1 + int32_t frac = x - (1LL << log2x); + if (log2x < kLogScaleLog2) { + frac <<= kLogScaleLog2 - log2x; + } else { + frac >>= log2x - kLogScaleLog2; + } + // Part 2 + const uint32_t base_seg = frac >> (kLogScaleLog2 - kLogSegmentsLog2); + const uint32_t seg_unit = + (((uint32_t) 1) << kLogScaleLog2) >> kLogSegmentsLog2; + + const int32_t c0 = kLogLut[base_seg]; + const int32_t c1 = kLogLut[base_seg + 1]; + const int32_t seg_base = seg_unit * base_seg; + const int32_t rel_pos = ((c1 - c0) * (frac - seg_base)) >> kLogScaleLog2; + return frac + c0 + rel_pos; +} + +static uint32_t Log(const uint32_t x, const uint32_t scale_shift) { + const uint32_t integer = MostSignificantBit32(x) - 1; + const uint32_t fraction = Log2FractionPart(x, integer); + const uint32_t log2 = (integer << kLogScaleLog2) + fraction; + const uint32_t round = kLogScale / 2; + const uint32_t loge = + (((uint64_t) kLogCoeff) * log2 + round) >> kLogScaleLog2; + // Finally scale to our output scale + const uint32_t loge_scaled = ((loge << scale_shift) + round) >> kLogScaleLog2; + return loge_scaled; +} + +uint16_t* LogScaleApply(struct LogScaleState* state, uint32_t* signal, + int signal_size, int correction_bits) { + const int scale_shift 
= state->scale_shift; + uint16_t* output = (uint16_t*) signal; + uint16_t* ret = output; + for (int i = 0; i < signal_size; ++i) { + uint32_t value = *signal++; + if (state->enable_log) { + if (correction_bits < 0) { + value >>= -correction_bits; + } else { + value <<= correction_bits; + } + if (value > 1) { + value = Log(value, scale_shift); + } else { + value = 0; + } + } + *output++ = (value < kuint16max) ? value : kuint16max; + } + return ret; +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h new file mode 100644 index 0000000000000000000000000000000000000000..8fd6099933049210aab2402e447dba8ad87406fa --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct LogScaleState { + int enable_log; + int scale_shift; +}; + +// Applies a fixed point logarithm to the signal and converts it to 16 bit. Note +// that the signal array will be modified. 
+uint16_t* LogScaleApply(struct LogScaleState* state, uint32_t* signal, + int signal_size, int correction_bits); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_H_ diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.c similarity index 59% rename from tensorflow/contrib/batching/batch_scheduler.h rename to tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.c index 8e94e1fd8b969d4fef8dbc8c322557f9da3833e6..f59cde951ca40ab1b0d4c8d4407789deb0df74f9 100644 --- a/tensorflow/contrib/batching/batch_scheduler.h +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.c @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.h" -#ifndef TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ -#define TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ - -#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" - -#endif // TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ +void LogScaleWriteMemmap(FILE* fp, const struct LogScaleState* state, + const char* variable) { + fprintf(fp, "%s->enable_log = %d;\n", variable, state->enable_log); + fprintf(fp, "%s->scale_shift = %d;\n", variable, state->scale_shift); +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.h new file mode 100644 index 0000000000000000000000000000000000000000..5444303b2445ac72d1069eff5c50db8622d4b536 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_IO_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_IO_H_ + +#include <stdio.h> + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void LogScaleWriteMemmap(FILE* fp, const struct LogScaleState* state, + const char* variable); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_IO_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_test.cc b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..312d7ea7406a14f52fb04b570f121307f5422b9b --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.h" + +#include +#include + +namespace { + +const int kScaleShift = 6; +const int kCorrectionBits = -1; + +TEST(LogScaleTest, CheckOutputValues) { + struct LogScaleState state; + state.enable_log = true; + state.scale_shift = kScaleShift; + + uint32_t fake_signal[] = {3578, 1533}; + uint16_t* output = LogScaleApply(&state, fake_signal, + sizeof(fake_signal) / sizeof(fake_signal[0]), + kCorrectionBits); + + const uint16_t expected[] = {479, 425}; + for (int i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) { + EXPECT_EQ(output[i], expected[i]); + } +} + +TEST(LogScaleTest, CheckOutputValuesNoLog) { + struct LogScaleState state; + state.enable_log = false; + state.scale_shift = kScaleShift; + + uint32_t fake_signal[] = {85964, 45998}; + uint16_t* output = LogScaleApply(&state, fake_signal, + sizeof(fake_signal) / sizeof(fake_signal[0]), + kCorrectionBits); + + const uint16_t expected[] = {65535, 45998}; + for (int i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) { + EXPECT_EQ(output[i], expected[i]); + } +} + +} // namespace diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.c new file mode 100644 index 0000000000000000000000000000000000000000..8a025fbf72d9db349d9459e1d1fd76858b97a661 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.c @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.h" + +void LogScaleFillConfigWithDefaults(struct LogScaleConfig* config) { + config->enable_log = 1; + config->scale_shift = 6; +} + +int LogScalePopulateState(const struct LogScaleConfig* config, + struct LogScaleState* state) { + state->enable_log = config->enable_log; + state->scale_shift = config->scale_shift; + return 1; +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.h new file mode 100644 index 0000000000000000000000000000000000000000..33b21f30b1093077f640f1d0c15dc336e72d8c31 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_UTIL_H_ + +#include +#include + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct LogScaleConfig { + // set to false (0) to disable this module + int enable_log; + // scale results by 2^(scale_shift) + int scale_shift; +}; + +// Populates the LogScaleConfig with "sane" default values. +void LogScaleFillConfigWithDefaults(struct LogScaleConfig* config); + +// Copies enable_log and scale_shift from config into state. Returns 1. +int LogScalePopulateState(const struct LogScaleConfig* config, + struct LogScaleState* state); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.c new file mode 100644 index 0000000000000000000000000000000000000000..92f8b58d74f9d361af068eec19a5530a22a7715c --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.c @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h" + +#include <string.h> + +void NoiseReductionApply(struct NoiseReductionState* state, uint32_t* signal) { + int i; + for (i = 0; i < state->num_channels; ++i) { + const uint32_t smoothing = + ((i & 1) == 0) ? state->even_smoothing : state->odd_smoothing; + const uint32_t one_minus_smoothing = (1 << kNoiseReductionBits) - smoothing; + + // Update the estimate of the noise. + const uint32_t signal_scaled_up = signal[i] << state->smoothing_bits; + uint32_t estimate = + (((uint64_t) signal_scaled_up * smoothing) + + ((uint64_t) state->estimate[i] * one_minus_smoothing)) >> + kNoiseReductionBits; + state->estimate[i] = estimate; + + // Make sure that we can't get a negative value for the signal - estimate. + if (estimate > signal_scaled_up) { + estimate = signal_scaled_up; + } + + const uint32_t floor = + ((uint64_t) signal[i] * state->min_signal_remaining) >> + kNoiseReductionBits; + const uint32_t subtracted = (signal_scaled_up - estimate) >> + state->smoothing_bits; + const uint32_t output = subtracted > floor ? subtracted : floor; + signal[i] = output; + } +} + +void NoiseReductionReset(struct NoiseReductionState* state) { + memset(state->estimate, 0, sizeof(*state->estimate) * state->num_channels); +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h new file mode 100644 index 0000000000000000000000000000000000000000..cc2cf2d9b742f9d9a0e12700e4bdf0a74ea589d7 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_H_ + +#define kNoiseReductionBits 14 + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct NoiseReductionState { + int smoothing_bits; + uint16_t even_smoothing; + uint16_t odd_smoothing; + uint16_t min_signal_remaining; + int num_channels; + uint32_t* estimate; +}; + +// Removes stationary noise from each channel of the signal using a low pass +// filter. +void NoiseReductionApply(struct NoiseReductionState* state, uint32_t* signal); + +void NoiseReductionReset(struct NoiseReductionState* state); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.c new file mode 100644 index 0000000000000000000000000000000000000000..1cba410436ad2b541207cc2ba7aaf6d54e71d172 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.c @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.h" + +void NoiseReductionWriteMemmapPreamble( + FILE* fp, const struct NoiseReductionState* state) { + fprintf(fp, "static uint32_t noise_reduction_estimate[%d];\n", + state->num_channels); // num_channels is an int, so %d (not %zu) + fprintf(fp, "\n"); +} + +void NoiseReductionWriteMemmap(FILE* fp, + const struct NoiseReductionState* state, + const char* variable) { + fprintf(fp, "%s->even_smoothing = %d;\n", variable, state->even_smoothing); + fprintf(fp, "%s->odd_smoothing = %d;\n", variable, state->odd_smoothing); + fprintf(fp, "%s->min_signal_remaining = %d;\n", variable, + state->min_signal_remaining); + fprintf(fp, "%s->num_channels = %d;\n", variable, state->num_channels); + // NOTE(review): smoothing_bits is never emitted here; confirm it is set. + fprintf(fp, "%s->estimate = noise_reduction_estimate;\n", variable); +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.h new file mode 100644 index 0000000000000000000000000000000000000000..afeedfce99d09b934e7d63ff3ce60cf4928f3c11 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_IO_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_IO_H_ + +#include <stdio.h> + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void NoiseReductionWriteMemmapPreamble(FILE* fp, + const struct NoiseReductionState* state); +void NoiseReductionWriteMemmap(FILE* fp, + const struct NoiseReductionState* state, + const char* variable); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_IO_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_test.cc b/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f4cf486227a2607774a82dd45c9d9ee4c2bcdb16 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_test.cc @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.h" + +#include +#include + +namespace { + +const int kNumChannels = 2; + +// Test noise reduction using default config values. +class NoiseReductionTest : public ::testing::Test { + protected: + NoiseReductionTest() { + config_.smoothing_bits = 10; + config_.even_smoothing = 0.025; + config_.odd_smoothing = 0.06; + config_.min_signal_remaining = 0.05; + } + + struct NoiseReductionConfig config_; +}; + +TEST_F(NoiseReductionTest, TestNoiseReductionEstimate) { + struct NoiseReductionState state; + ASSERT_TRUE(NoiseReductionPopulateState(&config_, &state, kNumChannels)); + + uint32_t signal[] = {247311, 508620}; + NoiseReductionApply(&state, signal); + + const uint32_t expected[] = {6321887, 31248341}; + ASSERT_EQ(state.num_channels, sizeof(expected) / sizeof(expected[0])); + for (int i = 0; i < state.num_channels; ++i) { + EXPECT_EQ(state.estimate[i], expected[i]); + } + + NoiseReductionFreeStateContents(&state); +} + +TEST_F(NoiseReductionTest, TestNoiseReduction) { + struct NoiseReductionState state; + ASSERT_TRUE(NoiseReductionPopulateState(&config_, &state, kNumChannels)); + + uint32_t signal[] = {247311, 508620}; + NoiseReductionApply(&state, signal); + + const uint32_t expected[] = {241137, 478104}; + ASSERT_EQ(state.num_channels, sizeof(expected) / sizeof(expected[0])); + for (int i = 0; i < state.num_channels; ++i) { 
+ EXPECT_EQ(signal[i], expected[i]); + } + + NoiseReductionFreeStateContents(&state); +} + +} // namespace diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.c new file mode 100644 index 0000000000000000000000000000000000000000..46f475352e0670a2618e3796c7495343b32d013e --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.c @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.h" + +#include + +void NoiseReductionFillConfigWithDefaults(struct NoiseReductionConfig* config) { + config->smoothing_bits = 10; + config->even_smoothing = 0.025; + config->odd_smoothing = 0.06; + config->min_signal_remaining = 0.05; +} + +int NoiseReductionPopulateState(const struct NoiseReductionConfig* config, + struct NoiseReductionState* state, + int num_channels) { + state->smoothing_bits = config->smoothing_bits; + state->odd_smoothing = config->odd_smoothing * (1 << kNoiseReductionBits); + state->even_smoothing = config->even_smoothing * (1 << kNoiseReductionBits); + state->min_signal_remaining = + config->min_signal_remaining * (1 << kNoiseReductionBits); + state->num_channels = num_channels; + state->estimate = calloc(state->num_channels, sizeof(*state->estimate)); + if (state->estimate == NULL) { + fprintf(stderr, "Failed to alloc estimate buffer\n"); + return 0; + } + return 1; +} + +void NoiseReductionFreeStateContents(struct NoiseReductionState* state) { + free(state->estimate); +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.h new file mode 100644 index 0000000000000000000000000000000000000000..207b8a679dac9d3cf2e933f499477b2e03c12cb4 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_UTIL_H_ + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct NoiseReductionConfig { + // scale the signal up by 2^(smoothing_bits) before reduction + int smoothing_bits; + // smoothing coefficient for even-numbered channels + float even_smoothing; + // smoothing coefficient for odd-numbered channels + float odd_smoothing; + // fraction of signal to preserve (1.0 disables this module) + float min_signal_remaining; +}; + +// Populates the NoiseReductionConfig with "sane" default values. +void NoiseReductionFillConfigWithDefaults(struct NoiseReductionConfig* config); + +// Allocates any buffers. +int NoiseReductionPopulateState(const struct NoiseReductionConfig* config, + struct NoiseReductionState* state, + int num_channels); + +// Frees any allocated buffers. 
+void NoiseReductionFreeStateContents(struct NoiseReductionState* state);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+
+#endif  // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_UTIL_H_
diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.c
new file mode 100644
index 0000000000000000000000000000000000000000..551d552e8f63a487debdacba63fb11cb67de0e2f
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.c
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.h"
+
+#include "tensorflow/contrib/lite/experimental/microfrontend/lib/bits.h"
+
+// Evaluates a piecewise-quadratic function of x from a lookup table.
+//
+// For x <= 2 the result is read directly from lut[x]. For larger x the table
+// segment starting at lut[4 * interval - 6] supplies three values that are
+// combined as lut[0] + lut[1] * t + lut[2] * t^2, where t = frac / 1024 and
+// frac is the low 10 bits of x after it has been shifted by |11 - interval|.
+// The final shift by 15 is rounded to nearest by adding 1 << 14 first.
+//
+// NOTE(review): `interval` comes from MostSignificantBit32() in bits.h, which
+// is not shown here; presumably it is the 1-based index of the highest set
+// bit, which would make frac the 10 bits just below that bit -- confirm.
+int16_t WideDynamicFunction(const uint32_t x, const int16_t* lut) {
+  if (x <= 2) {
+    return lut[x];
+  }
+
+  const int16_t interval = MostSignificantBit32(x);
+  // Step to the segment for this interval (four table slots per interval).
+  lut += 4 * interval - 6;
+
+  const int16_t frac = ((interval < 11)
+                        ? (x << (11 - interval))
+                        : (x >> (interval - 11))
+                       ) & 0x3FF;
+
+  // Horner-style evaluation of (lut[2] * t + lut[1]) * t in fixed point.
+  int32_t result = ((int32_t) lut[2] * frac) >> 5;
+  result += ((int32_t) lut[1]) << 5;
+  result *= frac;
+  result = (result + (1 << 14)) >> 15;
+  result += lut[0];
+  return (int16_t) result;
+}
+
+// Compresses x with a two-piece curve: quadratic, (x * x) >> 20, while
+// x < 2 << kPcanSnrBits, and linear, (x >> 6) - 64, from there on (shift
+// amounts shown for kPcanSnrBits = 12 and kPcanOutputBits = 6). Both pieces
+// evaluate to 64 at the crossover x = 8192.
+uint32_t PcanShrink(const uint32_t x) {
+  if (x < (2 << kPcanSnrBits)) {
+    return (x * x) >> (2 + 2 * kPcanSnrBits - kPcanOutputBits);
+  } else {
+    return (x >> (kPcanSnrBits - kPcanOutputBits)) - (1 << kPcanOutputBits);
+  }
+}
+
+// Rewrites signal[0 .. state->num_channels) in place. For each channel a gain
+// is looked up from state->noise_estimate[i] through state->gain_lut, the
+// sample is multiplied by it in 64 bits and shifted right by
+// state->snr_shift, and the result is passed through PcanShrink().
+// This function does not read state->enable_pcan itself.
+void PcanGainControlApply(struct PcanGainControlState* state,
+                          uint32_t* signal) {
+  for (int i = 0; i < state->num_channels; ++i) {
+    const uint32_t gain = WideDynamicFunction(state->noise_estimate[i],
+                                              state->gain_lut);
+    const uint32_t snr = ((uint64_t) signal[i] * gain) >> state->snr_shift;
+    signal[i] = PcanShrink(snr);
+  }
+}
diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.h
new file mode 100644
index 0000000000000000000000000000000000000000..cab74f49dbece640cc2925607260ba74a95a76a1
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_H_ + +#include +#include + +#define kPcanSnrBits 12 +#define kPcanOutputBits 6 + +#ifdef __cplusplus +extern "C" { +#endif + +struct PcanGainControlState { + int enable_pcan; + uint32_t* noise_estimate; + int num_channels; + int16_t* gain_lut; + int32_t snr_shift; +}; + +int16_t WideDynamicFunction(const uint32_t x, const int16_t* lut); + +uint32_t PcanShrink(const uint32_t x); + +void PcanGainControlApply(struct PcanGainControlState* state, uint32_t* signal); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_test.cc b/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bbc36d6eac7757e3e7d3db7691aa11785c6deb37 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.h" + +#include +#include + +namespace { + +const int kNumChannels = 2; +const int kSmoothingBits = 10; +const int kCorrectionBits = -1; + +// Test pcan auto gain control using default config values. +class PcanGainControlTest : public ::testing::Test { + protected: + PcanGainControlTest() { + config_.enable_pcan = 1; + config_.strength = 0.95; + config_.offset = 80.0; + config_.gain_bits = 21; + } + + struct PcanGainControlConfig config_; +}; + +TEST_F(PcanGainControlTest, TestPcanGainControl) { + uint32_t estimate[] = {6321887, 31248341}; + struct PcanGainControlState state; + ASSERT_TRUE(PcanGainControlPopulateState(&config_, &state, estimate, + kNumChannels, kSmoothingBits, + kCorrectionBits)); + + uint32_t signal[] = {241137, 478104}; + PcanGainControlApply(&state, signal); + + const uint32_t expected[] = {3578, 1533}; + ASSERT_EQ(state.num_channels, sizeof(expected) / sizeof(expected[0])); + for (int i = 0; i < state.num_channels; ++i) { + EXPECT_EQ(signal[i], expected[i]); + } + + PcanGainControlFreeStateContents(&state); +} + +} // namespace diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.c new file mode 100644 index 0000000000000000000000000000000000000000..4226b390bc1427ae2131ec11db7338723f0e6ba9 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.c @@ -0,0 +1,90 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.h" + +#include +#include + +#define kint16max 0x00007FFF + +void PcanGainControlFillConfigWithDefaults( + struct PcanGainControlConfig* config) { + config->enable_pcan = 0; + config->strength = 0.95; + config->offset = 80.0; + config->gain_bits = 21; +} + +int16_t PcanGainLookupFunction(const struct PcanGainControlConfig* config, + int32_t input_bits, uint32_t x) { + const float x_as_float = ((float) x) / ((uint32_t) 1 << input_bits); + const float gain_as_float = ((uint32_t) 1 << config->gain_bits) * + powf(x_as_float + config->offset, -config->strength); + + if (gain_as_float > kint16max) { + return kint16max; + } + return (int16_t) (gain_as_float + 0.5f); +} + +int PcanGainControlPopulateState(const struct PcanGainControlConfig* config, + struct PcanGainControlState* state, + uint32_t* noise_estimate, + const int num_channels, + const uint16_t smoothing_bits, + const int32_t input_correction_bits) { + state->enable_pcan = config->enable_pcan; + if (!state->enable_pcan) { + return 1; + } + state->noise_estimate = noise_estimate; + state->num_channels = num_channels; + state->gain_lut = malloc(kWideDynamicFunctionLUTSize * sizeof(int16_t)); + if (state->gain_lut == NULL) { + fprintf(stderr, "Failed to allocate gain LUT\n"); + return 0; + } + state->snr_shift = config->gain_bits - input_correction_bits - kPcanSnrBits; + + const int32_t input_bits = smoothing_bits - input_correction_bits; + 
state->gain_lut[0] = PcanGainLookupFunction(config, input_bits, 0); + state->gain_lut[1] = PcanGainLookupFunction(config, input_bits, 1); + state->gain_lut -= 6; + for (int interval = 2; interval <= kWideDynamicFunctionBits; ++interval) { + const uint32_t x0 = (uint32_t) 1 << (interval - 1); + const uint32_t x1 = x0 + (x0 >> 1); + const uint32_t x2 = (interval == kWideDynamicFunctionBits) + ? x0 + (x0 - 1) : 2 * x0; + + const int16_t y0 = PcanGainLookupFunction(config, input_bits, x0); + const int16_t y1 = PcanGainLookupFunction(config, input_bits, x1); + const int16_t y2 = PcanGainLookupFunction(config, input_bits, x2); + + const int32_t diff1 = (int32_t) y1 - y0; + const int32_t diff2 = (int32_t) y2 - y0; + const int32_t a1 = 4 * diff1 - diff2; + const int32_t a2 = diff2 - a1; + + state->gain_lut[4 * interval] = y0; + state->gain_lut[4 * interval + 1] = (int16_t) a1; + state->gain_lut[4 * interval + 2] = (int16_t) a2; + } + state->gain_lut += 6; + return 1; +} + +void PcanGainControlFreeStateContents(struct PcanGainControlState* state) { + free(state->gain_lut); +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.h new file mode 100644 index 0000000000000000000000000000000000000000..79c0b1da693651c438ed8b5f85c19ba23850e738 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.h @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_UTIL_H_
+
+#include "tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.h"
+
+// Number of intervals covered by the gain lookup table, and the number of
+// int16_t entries allocated for it by PcanGainControlPopulateState().
+#define kWideDynamicFunctionBits 32
+#define kWideDynamicFunctionLUTSize (4 * kWideDynamicFunctionBits - 3)
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// User-facing settings for the PCAN gain-control stage.
+struct PcanGainControlConfig {
+  // set to false (0) to disable this module
+  int enable_pcan;
+  // gain normalization exponent (0.0 disables, 1.0 full strength)
+  float strength;
+  // positive value added in the normalization denominator
+  float offset;
+  // number of fractional bits in the gain
+  int gain_bits;
+};
+
+// Fills `config` with default values (PCAN disabled, strength 0.95,
+// offset 80.0, gain_bits 21).
+void PcanGainControlFillConfigWithDefaults(
+    struct PcanGainControlConfig* config);
+
+// Computes one fixed-point gain value for the raw input x, saturated to the
+// int16_t range.
+int16_t PcanGainLookupFunction(const struct PcanGainControlConfig* config,
+                               int32_t input_bits, uint32_t x);
+
+// Initializes `state` and allocates its lookup table. Returns 1 on success
+// and 0 when the allocation fails. `noise_estimate` is stored by pointer.
+int PcanGainControlPopulateState(const struct PcanGainControlConfig* config,
+                                 struct PcanGainControlState* state,
+                                 uint32_t* noise_estimate,
+                                 const int num_channels,
+                                 const uint16_t smoothing_bits,
+                                 const int32_t input_correction_bits);
+
+// Frees the lookup table held by `state`.
+void PcanGainControlFreeStateContents(struct PcanGainControlState* state);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+
+#endif  // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_UTIL_H_
diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/window.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/window.c
new file mode 100644
index 0000000000000000000000000000000000000000..0fdc040a7a58c7d3bd3ef8eb385e0c0c3d415236
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/window.c
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors.
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window.h" + +#include + +int WindowProcessSamples(struct WindowState* state, const int16_t* samples, + size_t num_samples, size_t* num_samples_read) { + const int size = state->size; + + // Copy samples from the samples buffer over to our local input. + size_t max_samples_to_copy = state->size - state->input_used; + if (max_samples_to_copy > num_samples) { + max_samples_to_copy = num_samples; + } + memcpy(state->input + state->input_used, samples, + max_samples_to_copy * sizeof(*samples)); + *num_samples_read = max_samples_to_copy; + state->input_used += max_samples_to_copy; + + if (state->input_used < state->size) { + // We don't have enough samples to compute a window. + return 0; + } + + // Apply the window to the input. + const int16_t* coefficients = state->coefficients; + const int16_t* input = state->input; + int16_t* output = state->output; + int i; + int16_t max_abs_output_value = 0; + for (i = 0; i < size; ++i) { + int16_t new_value = + (((int32_t) *input++) * *coefficients++) >> kFrontendWindowBits; + *output++ = new_value; + if (new_value < 0) { + new_value = -new_value; + } + if (new_value > max_abs_output_value) { + max_abs_output_value = new_value; + } + } + // Shuffle the input down by the step size, and update how much we have used. 
+ memmove(state->input, state->input + state->step, + sizeof(*state->input) * (state->size - state->step)); + state->input_used -= state->step; + state->max_abs_output_value = max_abs_output_value; + + // Indicate that the output buffer is valid for the next stage. + return 1; +} + +void WindowReset(struct WindowState* state) { + memset(state->input, 0, state->size * sizeof(*state->input)); + memset(state->output, 0, state->size * sizeof(*state->output)); + state->input_used = 0; + state->max_abs_output_value = 0; +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/window.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/window.h new file mode 100644 index 0000000000000000000000000000000000000000..90291e5c7238b5c789d46910f256686be3edef7c --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/window.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_H_ + +#include +#include + +#define kFrontendWindowBits 12 + +#ifdef __cplusplus +extern "C" { +#endif + +struct WindowState { + size_t size; + int16_t* coefficients; + size_t step; + + int16_t* input; + size_t input_used; + int16_t* output; + int16_t max_abs_output_value; +}; + +// Applies a window to the samples coming in, stepping forward at the given +// rate. +int WindowProcessSamples(struct WindowState* state, const int16_t* samples, + size_t num_samples, size_t* num_samples_read); + +void WindowReset(struct WindowState* state); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.c new file mode 100644 index 0000000000000000000000000000000000000000..f1fee7c1eda9d0002913bf3f214c9a90517c1f15 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.c @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.h" + +void WindowWriteMemmapPreamble(FILE* fp, const struct WindowState* state) { + fprintf(fp, "static int16_t window_coefficients[] = {\n"); + for (int i = 0; i < state->size; ++i) { + fprintf(fp, "%d", state->coefficients[i]); + if (i < state->size - 1) { + fprintf(fp, ", "); + } + } + fprintf(fp, "};\n"); + fprintf(fp, "static int16_t window_input[%zu];\n", state->size); + fprintf(fp, "static int16_t window_output[%zu];\n", state->size); + fprintf(fp, "\n"); +} + +void WindowWriteMemmap(FILE* fp, const struct WindowState* state, + const char* variable) { + fprintf(fp, "%s->size = %zu;\n", variable, state->size); + fprintf(fp, "%s->coefficients = window_coefficients;\n", variable); + fprintf(fp, "%s->step = %zu;\n", variable, state->step); + + fprintf(fp, "%s->input = window_input;\n", variable); + fprintf(fp, "%s->input_used = %zu;\n", variable, state->input_used); + fprintf(fp, "%s->output = window_output;\n", variable); + fprintf(fp, "%s->max_abs_output_value = %d;\n", variable, + state->max_abs_output_value); +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.h new file mode 100644 index 0000000000000000000000000000000000000000..2bab9064c1fa70154b3bc8d6674f1ce0f4f95486 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_IO_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_IO_H_ + +#include + +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void WindowWriteMemmapPreamble(FILE* fp, const struct WindowState* state); +void WindowWriteMemmap(FILE* fp, const struct WindowState* state, + const char* variable); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_IO_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_test.cc b/tensorflow/contrib/lite/experimental/microfrontend/lib/window_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a6c0879faa8ac544a426177213f1fafaf4ee54fe --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/window_test.cc @@ -0,0 +1,157 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window.h" +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.h" + +#include +#include + +namespace { + +const int kSampleRate = 1000; +const int kWindowSamples = 25; +const int kStepSamples = 10; +const int16_t kFakeAudioData[] = { + 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768, + 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768, + 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768}; + +// Test window function behaviors using default config values. +class WindowTest : public ::testing::Test { + protected: + WindowTest() { + config_.size_ms = 25; + config_.step_size_ms = 10; + } + + struct WindowConfig config_; +}; + +TEST_F(WindowTest, CheckCoefficients) { + struct WindowState state; + ASSERT_TRUE(WindowPopulateState(&config_, &state, kSampleRate)); + + const int16_t expected[] = {16, 144, 391, 743, 1176, 1664, 2177, + 2681, 3145, 3541, 3843, 4032, 4096, 4032, + 3843, 3541, 3145, 2681, 2177, 1664, 1176, + 743, 391, 144, 16}; + ASSERT_EQ(state.size, sizeof(expected) / sizeof(expected[0])); + for (int i = 0; i < state.size; ++i) { + EXPECT_EQ(state.coefficients[i], expected[i]); + } + + WindowFreeStateContents(&state); +} + +TEST_F(WindowTest, CheckResidualInput) { + struct WindowState state; + ASSERT_TRUE(WindowPopulateState(&config_, &state, kSampleRate)); + size_t num_samples_read; + + ASSERT_TRUE(WindowProcessSamples( + &state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); + + for (int i = kStepSamples; i < kWindowSamples; ++i) { + EXPECT_EQ(state.input[i - kStepSamples], kFakeAudioData[i]); + } + + WindowFreeStateContents(&state); +} + +TEST_F(WindowTest, CheckOutputValues) { + struct WindowState state; + 
ASSERT_TRUE(WindowPopulateState(&config_, &state, kSampleRate)); + size_t num_samples_read; + + ASSERT_TRUE(WindowProcessSamples( + &state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); + + const int16_t expected[] = { + 0, 1151, 0, -5944, 0, 13311, 0, -21448, 0, 28327, 0, -32256, 0, 32255, + 0, -28328, 0, 21447, 0, -13312, 0, 5943, 0, -1152, 0}; + ASSERT_EQ(state.size, sizeof(expected) / sizeof(expected[0])); + for (int i = 0; i < state.size; ++i) { + EXPECT_EQ(state.output[i], expected[i]); + } + + WindowFreeStateContents(&state); +} + +TEST_F(WindowTest, CheckMaxAbsValue) { + struct WindowState state; + ASSERT_TRUE(WindowPopulateState(&config_, &state, kSampleRate)); + size_t num_samples_read; + + ASSERT_TRUE(WindowProcessSamples( + &state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); + + EXPECT_EQ(state.max_abs_output_value, 32256); + + WindowFreeStateContents(&state); +} + +TEST_F(WindowTest, CheckConsecutiveWindow) { + struct WindowState state; + ASSERT_TRUE(WindowPopulateState(&config_, &state, kSampleRate)); + size_t num_samples_read; + + ASSERT_TRUE(WindowProcessSamples( + &state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); + ASSERT_TRUE(WindowProcessSamples( + &state, kFakeAudioData + kWindowSamples, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - kWindowSamples, + &num_samples_read)); + + const int16_t expected[] = { + 0, -1152, 0, 5943, 0, -13312, 0, 21447, 0, -28328, 0, 32255, 0, -32256, + 0, 28327, 0, -21448, 0, 13311, 0, -5944, 0, 1151, 0}; + ASSERT_EQ(state.size, sizeof(expected) / sizeof(expected[0])); + for (int i = 0; i < state.size; ++i) { + EXPECT_EQ(state.output[i], expected[i]); + } + + WindowFreeStateContents(&state); +} + +TEST_F(WindowTest, CheckNotEnoughSamples) { + struct WindowState state; + ASSERT_TRUE(WindowPopulateState(&config_, &state, kSampleRate)); + size_t num_samples_read; + + 
ASSERT_TRUE(WindowProcessSamples( + &state, kFakeAudioData, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]), &num_samples_read)); + ASSERT_TRUE(WindowProcessSamples( + &state, kFakeAudioData + kWindowSamples, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - kWindowSamples, + &num_samples_read)); + ASSERT_FALSE(WindowProcessSamples( + &state, kFakeAudioData + kWindowSamples + kStepSamples, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - kWindowSamples - + kStepSamples, + &num_samples_read)); + + EXPECT_EQ( + state.input_used, + sizeof(kFakeAudioData) / sizeof(kFakeAudioData[0]) - 2 * kStepSamples); + + WindowFreeStateContents(&state); +} + +} // namespace diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.c b/tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.c new file mode 100644 index 0000000000000000000000000000000000000000..3adde0fb0a6855ef8372f80b55df4f4d83fc9dc7 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.c @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.h" + +#include +#include +#include +#include + +void WindowFillConfigWithDefaults(struct WindowConfig* config) { + config->size_ms = 25; + config->step_size_ms = 10; +} + +int WindowPopulateState(const struct WindowConfig* config, + struct WindowState* state, int sample_rate) { + state->size = config->size_ms * sample_rate / 1000; + state->step = config->step_size_ms * sample_rate / 1000; + + state->coefficients = malloc( + state->size * sizeof(*state->coefficients)); + if (state->coefficients == NULL) { + fprintf(stderr, "Failed to allocate window coefficients\n"); + return 0; + } + + // Populate the window values. + const float arg = M_PI * 2.0 / ((float) state->size); + int i; + for (i = 0; i < state->size; ++i) { + float float_value = 0.5 - (0.5 * cos(arg * (i + 0.5))); + // Scale it to fixed point and round it. + state->coefficients[i] = + floor(float_value * (1 << kFrontendWindowBits) + 0.5); + } + + state->input_used = 0; + state->input = malloc( + state->size * sizeof(*state->input)); + if (state->input == NULL) { + fprintf(stderr, "Failed to allocate window input\n"); + return 0; + } + + state->output = malloc( + state->size * sizeof(*state->output)); + if (state->output == NULL) { + fprintf(stderr, "Failed to allocate window output\n"); + return 0; + } + + return 1; +} + +void WindowFreeStateContents(struct WindowState* state) { + free(state->coefficients); + free(state->input); + free(state->output); +} diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.h b/tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.h new file mode 100644 index 0000000000000000000000000000000000000000..52dc8f38cc8bfc799f1f35a1d1a3b894bdfcf944 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The 
TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_UTIL_H_
+
+#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// User-facing settings for the windowing stage. WindowPopulateState()
+// converts both durations to sample counts using the sample rate it is given.
+struct WindowConfig {
+  // length of window frame in milliseconds
+  size_t size_ms;
+  // length of step for next frame in milliseconds
+  size_t step_size_ms;
+};
+
+// Populates the WindowConfig with "sane" default values.
+// The defaults are a 25 ms window and a 10 ms step.
+void WindowFillConfigWithDefaults(struct WindowConfig* config);
+
+// Allocates any buffers.
+// Also computes the window coefficients. Returns 1 on success and 0 if an
+// allocation fails.
+int WindowPopulateState(const struct WindowConfig* config,
+                        struct WindowState* state, int sample_rate);
+
+// Frees any allocated buffers.
+void WindowFreeStateContents(struct WindowState* state); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_UTIL_H_ diff --git a/tensorflow/contrib/lite/g3doc/_book.yaml b/tensorflow/contrib/lite/g3doc/_book.yaml index de6914e5366acef53a853a73f791dcfa801d444c..05c65441c3db4e74b6e7834437fa9cd0633af636 100644 --- a/tensorflow/contrib/lite/g3doc/_book.yaml +++ b/tensorflow/contrib/lite/g3doc/_book.yaml @@ -39,6 +39,16 @@ upper_tabs: - title: TensorFlow Lite for Raspberry Pi path: /lite/rpi + - heading: TF Lite converter + - title: Overview + path: /lite/convert/ + - title: Python API guide + path: /lite/convert/python_api + - title: Command line examples + path: /lite/convert/cmdline_examples + - title: Command line reference + path: /lite/convert/cmdline_reference + - title: TF Mobile style: accordion status: deprecated diff --git a/tensorflow/contrib/lite/g3doc/_index.yaml b/tensorflow/contrib/lite/g3doc/_index.yaml index bc66cc5dc1606537b7e186f3c825ab8335aa9e91..44ee6ba7505d421e46c8806ea5ca0ed4bc07f147 100644 --- a/tensorflow/contrib/lite/g3doc/_index.yaml +++ b/tensorflow/contrib/lite/g3doc/_index.yaml @@ -97,7 +97,7 @@ landing_page: path: https://www.shazam.com/ - custom_image: path: ./images/landing-page/nest_logo.png - path: https://nest.com/ + path: https://nest.com/ - custom_image: path: ./images/landing-page/loseit_logo.png path: https://www.loseit.com/ @@ -129,10 +129,10 @@ landing_page: icon_name: autorenew description: > Convert a TensorFlow model into a compressed flat buffer with the - TensorFlow Lite Optimizing Converter (TOCO). + TensorFlow Lite Converter. 
buttons: - - label: Read the TOCO guide - path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/python_api.md + - label: Read the converter guide + path: /lite/convert/ classname: button button-primary tfo-button-primary - heading: Deploy icon: diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/g3doc/convert/cmdline_examples.md similarity index 76% rename from tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md rename to tensorflow/contrib/lite/g3doc/convert/cmdline_examples.md index aba7536cbd3fbec509390158896e078e6379c848..44fb4f19aeb12fa83f76b6373bcbc148561d0747 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md +++ b/tensorflow/contrib/lite/g3doc/convert/cmdline_examples.md @@ -1,57 +1,33 @@ -# TensorFlow Lite Optimizing Converter command-line examples - -This page provides examples on how to use TOCO via command line. It is -complemented by the following documents: - -* [README](../README.md) -* [Command-line glossary](cmdline_reference.md) -* [Python API examples](python_api.md) - -Table of contents: - -* [Command-line tools](#tools) - * [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9) -* [Basic examples](#basic) - * [Convert a TensorFlow GraphDef](#graphdef) - * [Convert a TensorFlow SavedModel](#savedmodel) - * [Convert a tf.keras model](#keras) -* [Quantization](#quantization) - * [Convert a TensorFlow GraphDef for quantized inference](#graphdef-quant) - * [Use "dummy-quantization" to try out quantized inference on a float - graph](#dummy-quant) -* [Specifying input and output arrays](#specifying-input-and-output-arrays) - * [Multiple input arrays](#multiple-input-arrays) - * [Multiple output arrays](#multiple-output-arrays) - * [Specifying subgraphs](#specifying-subgraphs) -* [Graph visualizations](#graph-visualizations) - * [Using --output_format=GRAPHVIZ_DOT](#using-output-format-graphviz-dot) - * [Using 
--dump_graphviz_dir](#using-dump-graphviz-dir) - * [Graph "video" logging](#graph-video-logging) - * [Legend for the graph visualizations](#graphviz-legend) +# Converter command-line examples + +This page shows how to use the TensorFlow Lite Converter in the command line. + +[TOC] ## Command-line tools -There are two approaches to running TOCO via command line. +There are two approaches to running the converter in the command line. * `tflite_convert`: Starting from TensorFlow 1.9, the command-line tool - `tflite_convert` will be installed as part of the Python package. All of the + `tflite_convert` is installed as part of the Python package. All of the examples below use `tflite_convert` for simplicity. * Example: `tflite_convert --output_file=...` -* `bazel`: In order to run the latest version of TOCO, [clone the TensorFlow - repository](https://www.tensorflow.org/install/source) - and use `bazel`. This is the recommended approach for converting models that - utilize new features that were not supported by TOCO in TensorFlow 1.9. +* `bazel`: In order to run the latest version of the TensorFlow Lite Converter + either install the nightly build using + [pip](https://www.tensorflow.org/install/pip) or + [clone the TensorFlow repository](https://www.tensorflow.org/install/source) + and use `bazel`. * Example: `bazel run //tensorflow/contrib/lite/python:tflite_convert -- --output_file=...` -### Converting models prior to TensorFlow 1.9. +### Converting models prior to TensorFlow 1.9 -The recommended approach for using TOCO prior to TensorFlow 1.9 is the [Python -API](python_api.md#pre-tensorflow-1.9). If a command line tool is desired, the -`toco` command line tool was available in TensorFlow 1.7. Enter `toco --help` in -Terminal for additional details on the command-line flags available. There were -no command line tools in TensorFlow 1.8. +The recommended approach for using the converter prior to TensorFlow 1.9 is the +[Python API](python_api.md#pre_tensorflow_1.9). 
If a command line tool is +desired, the `toco` command line tool was available in TensorFlow 1.7. Enter +`toco --help` in Terminal for additional details on the command-line flags +available. There were no command line tools in TensorFlow 1.8. ## Basic examples @@ -115,11 +91,11 @@ tflite_convert \ ## Quantization -### Convert a TensorFlow GraphDef for quantized inference +### Convert a TensorFlow GraphDef for quantized inference -TOCO is compatible with fixed point quantization models described -[here](https://www.tensorflow.org/performance/quantization). These are float -models with +The TensorFlow Lite Converter is compatible with fixed point quantization models +described [here](https://www.tensorflow.org/performance/quantization). These are +float models with [`FakeQuant*`](https://www.tensorflow.org/api_guides/python/array_ops#Fake_quantization) ops inserted at the boundaries of fused layers to record min-max range information. This generates a quantized inference workload that reproduces the @@ -139,14 +115,14 @@ tflite_convert \ --std_dev_values=127 ``` -### Use \"dummy-quantization\" to try out quantized inference on a float graph +### Use \"dummy-quantization\" to try out quantized inference on a float graph -In order to evaluate the possible benefit of generating a quantized graph, TOCO -allows "dummy-quantization" on float graphs. The flags `--default_ranges_min` -and `--default_ranges_max` accept plausible values for the min-max ranges of the -values in all arrays that do not have min-max information. "Dummy-quantization" -will produce lower accuracy but will emulate the performance of a correctly -quantized model. +In order to evaluate the possible benefit of generating a quantized graph, the +converter allows "dummy-quantization" on float graphs. The flags +`--default_ranges_min` and `--default_ranges_max` accept plausible values for +the min-max ranges of the values in all arrays that do not have min-max +information. 
"Dummy-quantization" will produce lower accuracy but will emulate +the performance of a correctly quantized model. The example below contains a model using Relu6 activation functions. Therefore, a reasonable guess is that most activation ranges should be contained in [0, 6]. @@ -207,10 +183,10 @@ tflite_convert \ ### Specifying subgraphs Any array in the input file can be specified as an input or output array in -order to extract subgraphs out of an input graph file. TOCO discards the parts -of the graph outside of the specific subgraph. Use [graph -visualizations](#graph-visualizations) to identify the input and output arrays -that make up the desired subgraph. +order to extract subgraphs out of an input graph file. The TensorFlow Lite +Converter discards the parts of the graph outside of the specific subgraph. Use +[graph visualizations](#graph_visualizations) to identify the input and output +arrays that make up the desired subgraph. The follow command shows how to extract a single fused layer out of a TensorFlow GraphDef. @@ -247,11 +223,12 @@ function tends to get fused). ## Graph visualizations -TOCO can export a graph to the Graphviz Dot format for easy visualization via -either the `--output_format` flag or the `--dump_graphviz_dir` flag. The -subsections below outline the use cases for each. +The converter can export a graph to the Graphviz Dot format for easy +visualization using either the `--output_format` flag or the +`--dump_graphviz_dir` flag. The subsections below outline the use cases for +each. -### Using `--output_format=GRAPHVIZ_DOT` +### Using `--output_format=GRAPHVIZ_DOT` The first way to get a Graphviz rendering is to pass `GRAPHVIZ_DOT` into `--output_format`. This results in a plausible visualization of the graph. 
This @@ -323,10 +300,23 @@ As before, these can be rendered to PDFs: dot -Tpdf -O /tmp/toco_*.dot ``` -Sample output files can be seen here: - -* [toco_AT_IMPORT.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AT_IMPORT.dot.pdf) -* [toco_AFTER_TRANSFORMATIONS.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AFTER_TRANSFORMATIONS.dot.pdf). +Sample output files can be seen below. Note that it is the same +`AveragePool` node in the top right of each image. + + + + + +
+ + + + + + + +
beforeafter
### Graph "video" logging @@ -336,7 +326,7 @@ each individual graph transformation, resulting in thousands of files. Typically, one would then bisect into these files to understand when a given change was introduced in the graph. -### Legend for the graph visualizations +### Legend for the graph visualizations * Operators are red square boxes with the following hues of red: * Most operators are @@ -345,7 +335,7 @@ change was introduced in the graph. * Some typically heavy operators (e.g. Conv) are rendered in a darker red. -* Arrays are octogons with the following colors: +* Arrays are octagons with the following colors: * Constant arrays are blue. * Activation arrays are gray: diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/g3doc/convert/cmdline_reference.md similarity index 91% rename from tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md rename to tensorflow/contrib/lite/g3doc/convert/cmdline_reference.md index 00bc8d4ccb8aedcfe701377419e6cd41d0b59855..d72a46760d48dae46d63f1e914d8afda3f527e27 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md +++ b/tensorflow/contrib/lite/g3doc/convert/cmdline_reference.md @@ -1,19 +1,10 @@ -# TensorFlow Lite Optimizing Converter command-line glossary +# Converter command-line reference -This page is complete reference of command-line flags used by TOCO's command -line starting from TensorFlow 1.9 up until the most recent build of TensorFlow. -It is complemented by the following other documents: +This page is a complete reference of command-line flags used by the TensorFlow +Lite Converter's command line starting from TensorFlow 1.9 up until the most +recent build of TensorFlow.
-* [README](../README.md) -* [Command-line examples](cmdline_examples.md) -* [Python API examples](python_api.md) - -Table of contents: - -* [High-level flags](#high-level-flags) -* [Model flags](#model-flags) -* [Transformation flags](#transformation-flags) -* [Logging flags](#logging-flags) +[TOC] ## High-level flags @@ -32,7 +23,7 @@ files. The flag `--output_file` is always required. Additionally, either * `--output_format`. Type: string. Default: `TFLITE`. Specifies the format of the output file. Allowed values: * `TFLITE`: TensorFlow Lite FlatBuffer format. - * `GRAPHVIZ_DOT`: GraphViz `.dot` format containg a visualization of the + * `GRAPHVIZ_DOT`: GraphViz `.dot` format containing a visualization of the graph after graph transformations. * Note that passing `GRAPHVIZ_DOT` to `--output_format` leads to loss of TFLite specific transformations. Therefore, the resulting @@ -68,7 +59,7 @@ based on index. * `--input_shapes`. Type: colon-separated list of comma-separated lists of integers. Each comma-separated list of integers gives the shape of one of the input arrays specified in - [TensorFlow convention](https://www.tensorflow.org/versions/r1.2/programmers_guide/dims_types#shape). + [TensorFlow convention](https://www.tensorflow.org/guide/tensors#shape). * Example: `--input_shapes=1,60,80,3` for a typical vision model means a batch size of 1, an input image height of 60, an input image width of 80, and an input image depth of 3 (representing RGB channels). diff --git a/tensorflow/contrib/lite/g3doc/convert/index.md b/tensorflow/contrib/lite/g3doc/convert/index.md new file mode 100644 index 0000000000000000000000000000000000000000..bc92a1c1a11a6f3808e44f37d04704ece1627fc3 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/convert/index.md @@ -0,0 +1,19 @@ +# TensorFlow Lite Converter + +The TensorFlow Lite Converter takes a TensorFlow graph file and creates a graph +file used by the TensorFlow Lite interpreter. 
+ +## From model training to device deployment + +After a TensorFlow model is trained, the TensorFlow Lite converter uses that +model to generate a TensorFlow Lite [FlatBuffer](https://google.github.io/flatbuffers/) +file (`.tflite`). The converter supports as input: +[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators), +frozen graphs (models generated by +[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)), +and `tf.keras` models. The TensorFlow Lite `FlatBuffer` file is deployed to a +client device (generally a mobile or embedded device), and the TensorFlow Lite +interpreter uses the compressed model for on-device inference. This conversion +process is shown in the diagram below: + +![TFLite converter workflow](../images/convert/workflow.svg) diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/g3doc/convert/python_api.md similarity index 60% rename from tensorflow/contrib/lite/toco/g3doc/python_api.md rename to tensorflow/contrib/lite/g3doc/convert/python_api.md index 8c31c3dca865640ee1a60cbcc93b741f2d7d52cf..9dcb79187ec9bda487887327dbb575e8c580ba01 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/g3doc/convert/python_api.md @@ -1,67 +1,43 @@ -# TensorFlow Lite Optimizing Converter & Interpreter Python API reference - -This page provides examples on how to use TOCO and the TensorFlow Lite -interpreter via the Python API. 
It is complemented by the following documents: - -* [README](../README.md) -* [Command-line examples](cmdline_examples.md) -* [Command-line glossary](cmdline_reference.md) - -Table of contents: - -* [High-level overview](#high-level-overview) -* [API](#api) -* [Basic examples](#basic) - * [Exporting a GraphDef from tf.Session](#basic-graphdef-sess) - * [Exporting a GraphDef from file](#basic-graphdef-file) - * [Exporting a SavedModel](#basic-savedmodel) - * [Exporting a tf.keras File](#basic-keras-file) -* [Complex examples](#complex) - * [Exporting a quantized GraphDef](#complex-quant) -* [TensorFlow Lite Python interpreter](#interpreter) - * [Using the interpreter from a model file](#interpreter-file) - * [Using the interpreter from model data](#interpreter-data) -* [Additional instructions](#additional-instructions) - * [Build from source code](#latest-package) - * [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9) +# Converter Python API guide + +This page provides examples on how to use the TensorFlow Lite Converter and the +TensorFlow Lite interpreter using the Python API. + +[TOC] + ## High-level overview -While the TensorFlow Lite Optimizing Converter can be used from the command -line, it is often convenient to use it as part of a Python model build and -training script. This is so that conversion can be part of your model -development pipeline. This allows you to know early and often that you are -designing a model that can be targeted to devices with mobile. +While the TensorFlow Lite Converter can be used from the command line, it is +often convenient to use in a Python script as part of the model development +pipeline. This allows you to know early that you are designing a model that can +be targeted to devices with mobile. ## API The API for converting TensorFlow models to TensorFlow Lite as of TensorFlow 1.9 -is `tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is -`tf.contrib.lite.Interpreter`. 
- -**NOTE**: As of TensorFlow 1.12, the API for converting TensorFlow models to -TFLite will be renamed to `TFLiteConverter`. `TFLiteConverter` is semantically -identically to `TocoConverter`. The API is available at -`tf.contrib.lite.TFLiteConverter` as of the Sept 26 `tf-nightly`. - -`TocoConverter` provides class methods based on the original format of the -model. `TocoConverter.from_session()` is available for GraphDefs. -`TocoConverter.from_saved_model()` is available for SavedModels. -`TocoConverter.from_keras_model_file()` is available for `tf.Keras` files. +is `tf.contrib.lite.TFLiteConverter`. The API for calling the Python interpreter +is `tf.contrib.lite.Interpreter`. + +Note: Refer to the "Additional instructions" sections for converting TensorFlow +models to TensorFlow Lite +[in TensorFlow 1.9 to TensorFlow 1.11](#pre_tensorflow_1.11) and +[prior to TensorFlow 1.9](#pre_tensorflow_1.9). + +`TFLiteConverter` provides class methods based on the original format of the +model. `TFLiteConverter.from_session()` is available for GraphDefs. +`TFLiteConverter.from_saved_model()` is available for SavedModels. +`TFLiteConverter.from_keras_model_file()` is available for `tf.Keras` files. Example usages for simple float-point models are shown in [Basic Examples](#basic). Examples usages for more complex models is shown in [Complex Examples](#complex). -**NOTE**: Currently, `TocoConverter` will cause a fatal error to the Python -interpreter when the conversion fails. This will be remedied as soon as -possible. - ## Basic examples The following section shows examples of how to convert a basic float-point model from each of the supported data formats into a TensorFlow Lite FlatBuffers. -### Exporting a GraphDef from tf.Session +### Exporting a GraphDef from tf.Session The following example shows how to convert a TensorFlow GraphDef into a TensorFlow Lite FlatBuffer from a `tf.Session` object.
@@ -76,12 +52,12 @@ out = tf.identity(val, name="out") with tf.Session() as sess: sess.run(tf.global_variables_initializer()) - converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [img], [out]) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` -### Exporting a GraphDef from file +### Exporting a GraphDef from file The following example shows how to convert a TensorFlow GraphDef into a TensorFlow Lite FlatBuffer when the GraphDef is stored in a file. Both `.pb` and @@ -89,7 +65,7 @@ TensorFlow Lite FlatBuffer when the GraphDef is stored in a file. Both `.pb` and The example uses [Mobilenet_1.0_224](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz). -The function only supports GraphDefs frozen via +The function only supports GraphDefs frozen using [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). ```python @@ -99,13 +75,13 @@ graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb" input_arrays = ["input"] output_arrays = ["MobilenetV1/Predictions/Softmax"] -converter = tf.contrib.lite.TocoConverter.from_frozen_graph( +converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph( graph_def_file, input_arrays, output_arrays) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` -### Exporting a SavedModel +### Exporting a SavedModel The following example shows how to convert a SavedModel into a TensorFlow Lite FlatBuffer. @@ -113,25 +89,26 @@ FlatBuffer. 
```python import tensorflow as tf -converter = tf.contrib.lite.TocoConverter.from_saved_model(saved_model_dir) +converter = tf.contrib.lite.TFLiteConverter.from_saved_model(saved_model_dir) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` For more complex SavedModels, the optional parameters that can be passed into -`TocoConverter.from_saved_model()` are `input_arrays`, `input_shapes`, +`TFLiteConverter.from_saved_model()` are `input_arrays`, `input_shapes`, `output_arrays`, `tag_set` and `signature_key`. Details of each parameter are -available by running `help(tf.contrib.lite.TocoConverter)`. +available by running `help(tf.contrib.lite.TFLiteConverter)`. -### Exporting a tf.keras File +### Exporting a tf.keras File The following example shows how to convert a `tf.keras` model into a TensorFlow -Lite FlatBuffer. +Lite FlatBuffer. This example requires +[`h5py`](http://docs.h5py.org/en/latest/build.html) to be installed. ```python import tensorflow as tf -converter = tf.contrib.lite.TocoConverter.from_keras_model_file("keras_model.h5") +converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file("keras_model.h5") tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` @@ -163,7 +140,7 @@ keras_file = "keras_model.h5" tf.keras.models.save_model(model, keras_file) # Convert to TensorFlow Lite model. -converter = tf.contrib.lite.TocoConverter.from_keras_model_file(keras_file) +converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file(keras_file) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` @@ -173,13 +150,13 @@ open("converted_model.tflite", "wb").write(tflite_model) For models where the default value of the attributes is not sufficient, the attribute's values should be set before calling `convert()`. In order to call any constants use `tf.contrib.lite.constants.` as seen below with -`QUANTIZED_UINT8`. 
Run `help(tf.contrib.lite.TocoConverter)` in the Python +`QUANTIZED_UINT8`. Run `help(tf.contrib.lite.TFLiteConverter)` in the Python terminal for detailed documentation on the attributes. Although the examples are demonstrated on GraphDefs containing only constants. The same logic can be applied irrespective of the input data format. -### Exporting a quantized GraphDef +### Exporting a quantized GraphDef The following example shows how to convert a quantized model into a TensorFlow Lite FlatBuffer. @@ -193,7 +170,7 @@ val = img + const out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="output") with tf.Session() as sess: - converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [img], [out]) converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8 input_arrays = converter.get_input_arrays() converter.quantized_input_stats = {input_arrays[0] : (0., 1.)} # mean, std_dev @@ -203,7 +180,7 @@ with tf.Session() as sess: ## TensorFlow Lite Python interpreter -### Using the interpreter from a model file +### Using the interpreter from a model file The following example shows how to use the TensorFlow Lite Python interpreter when provided a TensorFlow Lite FlatBuffer file. The example also demonstrates @@ -233,7 +210,7 @@ output_data = interpreter.get_tensor(output_details[0]['index']) print(output_data) ``` -### Using the interpreter from model data +### Using the interpreter from model data The following example shows how to use the TensorFlow Lite Python interpreter when starting with the TensorFlow Lite Flatbuffer model previously loaded. 
This @@ -250,7 +227,7 @@ val = img + const out = tf.identity(val, name="out") with tf.Session() as sess: - converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [img], [out]) tflite_model = converter.convert() # Load TFLite model and allocate tensors. @@ -260,15 +237,22 @@ interpreter.allocate_tensors() ## Additional instructions -### Build from source code +### Build from source code + +In order to run the latest version of the TensorFlow Lite Converter Python API, +either install the nightly build with +[pip](https://www.tensorflow.org/install/pip) (recommended) or +[Docker](https://www.tensorflow.org/install/docker), or +[build the pip package from source](https://www.tensorflow.org/install/source). + +### Converting models in TensorFlow 1.9 to TensorFlow 1.11 -In order to run the latest version of the TOCO Python API, clone the TensorFlow -repository, configure the installation, and build and install the pip package. -Detailed instructions are available -[here](https://www.tensorflow.org/install/source). +To convert TensorFlow models to TensorFlow Lite in TensorFlow 1.9 through +TensorFlow 1.11, use `TocoConverter`. `TocoConverter` is semantically +identical to `TFLiteConverter`. -### Converting models prior to TensorFlow 1.9. +### Converting models prior to TensorFlow 1.9 -To use TOCO in TensorFlow 1.7 and TensorFlow 1.8, use the `toco_convert` -function. Run `help(tf.contrib.lite.toco_convert)` to get details about accepted -parameters. +To convert TensorFlow models to TensorFlow Lite in TensorFlow 1.7 and TensorFlow +1.8, use the `toco_convert` function. Run `help(tf.contrib.lite.toco_convert)` +to get details about accepted parameters.
diff --git a/tensorflow/contrib/lite/g3doc/images/convert/sample_after.png b/tensorflow/contrib/lite/g3doc/images/convert/sample_after.png new file mode 100644 index 0000000000000000000000000000000000000000..6c451f97903f7f70a9f28dee8abf6daeb7ec5693 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/convert/sample_after.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/convert/sample_before.png b/tensorflow/contrib/lite/g3doc/images/convert/sample_before.png new file mode 100644 index 0000000000000000000000000000000000000000..e5317ef295062e79c66430512ef1c45925858ce0 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/convert/sample_before.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/convert/workflow.svg b/tensorflow/contrib/lite/g3doc/images/convert/workflow.svg new file mode 100644 index 0000000000000000000000000000000000000000..3dfcbd67d8919bd1ffe2a09d7b291a7c3182fccd --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/images/convert/workflow.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_accuracy.png b/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_accuracy.png new file mode 100644 index 0000000000000000000000000000000000000000..44d0ccd3128dea1c947e57ccbc4e18b2d34cef88 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_accuracy.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_latency.png b/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_latency.png new file mode 100644 index 0000000000000000000000000000000000000000..94a6310612828db2370d19a094795341478e90f8 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_latency.png differ diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md index 
0ae9400068887745f3064409bd39cf41eea212ca..ed114527166da79dba2d92c3ffad78e9885f9e94 100644 --- a/tensorflow/contrib/lite/g3doc/performance.md +++ b/tensorflow/contrib/lite/g3doc/performance.md @@ -3,36 +3,43 @@ Mobile and embedded devices have limited computational resources and it is important to keep your application resource efficient. We have compiled a list of best practices and strategies you can use to optimize your model and application when using Tensorflow Lite. -## Choose the most efficient model for the problem -Some models may be too large to run on embedded devices. Instead of large models it is better to use a slightly less precise but smaller model for embedded devices. Smaller models not only use less disk space and memory but are generally faster and more energy efficient. One example of models optimized for mobile devices are [MobileNets](https://arxiv.org/abs/1704.04861), which are optimized for mobile vision applications. Tensorflow Lite [models page](models.md) lists several other models that have been optimized specifically for mobile and embedded devices. +## Choose the best model for the task +Depending on the task you will need to make a tradeoff between model complexity and size. If your task requires high accuracy then you may need a large and complex model. Some tasks may work with a less precise model, for these tasks it is better to use a smaller but less precise model. Smaller models not only use less disk space and memory but are generally faster and more energy efficient. For example, graphs below show accuracy and latency tradeoff for some common image classification models. + +![accuracy vs model size](images/performance/model_size_vs_accuracy.png "Accuracy vs Model size") + + +![latency vs model size](images/performance/model_size_vs_latency.png "Latency vs Model size") + +One example of models optimized for mobile devices are [MobileNets](https://arxiv.org/abs/1704.04861), which are optimized for mobile vision applications. 
Tensorflow Lite [models page](models.md) lists several other models that have been optimized specifically for mobile and embedded devices. You can retrain the listed models on your own dataset by using transfer learning. Check out our transfer learning tutorial for -[image classification] (https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0) and +[image classification](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0) and [object detection](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193). ## Profile your model -Before starting any optimization, it is a good practice to profile and benchmark your model. Tensorflow Lite [benchmarking tool](../tools/benchmark) has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time. +Once you have selected a candidate model that is right for your task, it is a good practice to profile and benchmark your model. Tensorflow Lite [benchmarking tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time. ## Profile and optimize operators in the graph If a particular operator appears frequently in the model and based on profiling you find the operator consuming the most amount of time, you can look into optimizing the operator. This scenario should be rare as Tensorflow Lite has optimized versions for most ops. However you may be able to write a faster version of a custom op, if you know the constraints in which the operator is executed. Check out our [custom operator documentation](custom_operators.md). 
## Quantize your model -If your model uses floating point weights or activations then it may be possible to reduce the size of model up to ~4x by using quantization and other model optimizations. Check out our [model optimization toolkit](https://www.tensorflow.org/performance/model_optimization) for details about optimizing your model. Fully quantized models can be remarkably power efficient as well. +If your model uses floating point weights or activations then it may be possible to reduce the size of model up to ~4x by using quantization and other model optimizations. Check out our [model optimization toolkit](https://www.tensorflow.org/performance/model_optimization) for details about optimizing your model. ## Tweak the number of threads -Tensorflow Lite supports multi-threaded kernels for many operators. You can increase the number of threads and speed up execution of operators. Increasing the number of threads will however make your model use more resources and power. For some applications latency may be more important than energy efficiency. You can increase the number of threads by setting the number of [interpreter](../interpreter.h) threads. +Tensorflow Lite supports multi-threaded kernels for many operators. You can increase the number of threads and speed up execution of operators. Increasing the number of threads will however make your model use more resources and power. For some applications latency may be more important than energy efficiency. You can increase the number of threads by setting the number of [interpreter](https://github.com/tensorflow/tensorflow/blob/1084594657a5d139102ac794f84d1427a710e39a/tensorflow/contrib/lite/interpreter.h#L337) threads. Multi-threaded execution however comes at the cost of increased performance variability depending on what else is been executed concurrently. This is particularly the case for mobile apps. 
For example, isolated tests may show 2x speed up vs single-threaded but if another app is executing at the same time may result in worst performance than single-threaded. ## Eliminate redundant copies -Tensorflow Lite is optimized to reduce redundant copies. The APIs allow user to [mmap a model file](https://github.com/tensorflow/tensorflow/blob/9982fd6c8831cbd2f58954f79ea71f26660393bc/tensorflow/contrib/lite/model.h#L152) and avoid copies. If your application is not careful, there can be redundant copies when feeding the input to the model and reading output from the model. Make sure to eliminate redundant copies. If you are using higher level APIs like Java API, make sure to carefully check the documentation for performance caveats. For example, the Java API is a lot faster if ByteBuffers are used as [inputs](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java#L151). +If your application is not careful, there can be redundant copies when feeding the input to the model and reading output from the model. Make sure to eliminate redundant copies. If you are using higher level APIs like Java API, make sure to carefully check the documentation for performance caveats. For example, the Java API is a lot faster if ByteBuffers are used as [inputs](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java#L151). ## Profile your application with platform specific tools Platform specific tools like [Android profiler](https://developer.android.com/studio/profile/android-profiler) and [Instruments](https://help.apple.com/instruments/mac/current/) provide a wealth of profiling information that can be used to debug your app. Sometimes the performance bug may be not in the model but in parts of application code that interact with the model. 
Make sure to familiarize yourself with platform specific profiling tools and best practices for your platform. -## Use hardware accelerators available on the device -Tensorflow Lite is working on adding support for accelerators like GPU and provides acceleration through [NNAPI](https://developer.android.com/ndk/guides/neuralnetworks/) on Android. -You can utilize these hardware accelerator backends to improve the speed and efficiency of your model. To enable NNAPI call [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/interpreter.h#L334) on the interpreter instance. +## Evaluate whether your model benefits from using hardware accelerators available on the device +Tensorflow Lite is working on adding support for accelerators like GPU and provides acceleration through [Neural Networks API](https://developer.android.com/ndk/guides/neuralnetworks/) on Android. +You can utilize these hardware accelerator backends to improve the speed and efficiency of your model. To enable Neural Networks API call [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/interpreter.h#L334) on the interpreter instance. ## Need more help The Tensorflow team is happy to help diagnose and address specific performance issues you may be facing. Please file a bug on [github](https://github.com/tensorflow/tensorflow/issues) with details of the issue. diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md index b0f32a8d6ca91229489c73c2c6f52d9c82d37b37..2eb776d10cf8ec68987d13b580eddf2f1bda8e78 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md @@ -1,6 +1,22 @@ - # Building TensorFlow on Android +Warning: We expect to deprecate TensorFlow Mobile in early 2019 + +
+

+ TensorFlow Lite is our main mobile and embedded offering. We are + working hard to close the feature gap between TensorFlow Mobile and + TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We + will give ample notice to our users when we get to that point and will + provide help and support to ensure easy migrations. +

+

+ In the meantime, please use TensorFlow Lite. If you have a feature request, + such as a missing op, please post to our GitHub. +

+
+ To get you started working with TensorFlow on Android, we'll walk through two ways to build our TensorFlow mobile demos and deploying them on an Android device. The first is Android Studio, which lets you build and deploy in an diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md index 49ad35d4e6a18f266d88e330626bae8bf1fc499f..15f0fd396134e40e89266182cb308080d9d250cb 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md @@ -1,6 +1,22 @@ - # Overview +Warning: We expect to deprecate TensorFlow Mobile in early 2019 + +
+

+ TensorFlow Lite is our main mobile and embedded offering. We are + working hard to close the feature gap between TensorFlow Mobile and + TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We + will give ample notice to our users when we get to that point and will + provide help and support to ensure easy migrations. +

+

+ In the meantime, please use TensorFlow Lite. If you have a feature request, + such as a missing op, please post to our GitHub. +

+
+ TensorFlow was designed to be a good deep learning solution for mobile platforms. Currently we have two solutions for deploying machine learning applications on mobile and embedded devices: TensorFlow for Mobile and diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md index be8b4100c89f4b02e651b1585faf438881c9119d..d922907cdc5fe5ccec8864b456586fce0293a0af 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md @@ -1,6 +1,22 @@ - # Building TensorFlow on iOS +Warning: We expect to deprecate TensorFlow Mobile in early 2019 + +
+

+ TensorFlow Lite is our main mobile and embedded offering. We are + working hard to close the feature gap between TensorFlow Mobile and + TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We + will give ample notice to our users when we get to that point and will + provide help and support to ensure easy migrations. +

+

+ In the meantime, please use TensorFlow Lite. If you have a feature request, + such as a missing op, please post to our GitHub. +

+
+ ## Using CocoaPods The simplest way to get started with TensorFlow on iOS is using the CocoaPods diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md index 4d4bb3bc081d613714271f8b0bf7461cb1e0f4d5..fd0e322c93493ed835ae7ec9766a708885c6ac88 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md @@ -1,6 +1,22 @@ - # Integrating TensorFlow libraries +Warning: We expect to deprecate TensorFlow Mobile in early 2019 + +
+

+ TensorFlow Lite is our main mobile and embedded offering. We are + working hard to close the feature gap between TensorFlow Mobile and + TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We + will give ample notice to our users when we get to that point and will + provide help and support to ensure easy migrations. +

+

+ In the meantime, please use TensorFlow Lite. If you have a feature request, + such as a missing op, please post to our GitHub. +

+
+ Once you have made some progress on a model that addresses the problem you’re trying to solve, it’s important to test it out inside your application immediately. There are often unexpected differences between your training data diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md index 7436594fd8580151ba66562eccd408cc7e6c4201..59ff8e774c6c63a01668aee7d6caeea01171468d 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md @@ -1,6 +1,22 @@ - # Optimizing for mobile +Warning: We expect to deprecate TensorFlow Mobile in early 2019 + +
+

+ TensorFlow Lite is our main mobile and embedded offering. We are + working hard to close the feature gap between TensorFlow Mobile and + TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We + will give ample notice to our users when we get to that point and will + provide help and support to ensure easy migrations. +

+

+ In the meantime, please use TensorFlow Lite. If you have a feature request, + such as a missing op, please post to our GitHub. +

+
+ There are some special issues that you have to deal with when you’re trying to ship on mobile or embedded devices, and you’ll need to think about these as you’re developing your model. diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md index d1c67d4c61608bcbc9b0bcee5b60f46a73b44692..1d373251ddf3ba6a0119bd57bf14caf100ef371a 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md @@ -1,6 +1,22 @@ - # Preparing models for mobile deployment +Warning: We expect to deprecate TensorFlow Mobile in early 2019 + +
+

+ TensorFlow Lite is our main mobile and embedded offering. We are + working hard to close the feature gap between TensorFlow Mobile and + TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We + will give ample notice to our users when we get to that point and will + provide help and support to ensure easy migrations. +

+

+ In the meantime, please use TensorFlow Lite. If you have a feature request, + such as a missing op, please post to our GitHub. +

+
+ The requirements for storing model information during training are very different from when you want to release it as part of a mobile app. This section covers the tools involved in converting from a training model to something diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 88e41ffc55d2b666bb4837c12dccb2ebcdcaac33..c72e7bf33ebbaac09916ffda6faf4b812d702ea8 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -952,7 +952,10 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, } } if (has_dynamic_tensors) { - ReportError(&context_, "Attempting to resize a fixed-size tensor."); + ReportError( + &context_, + "Attempting to use a delegate that only supports static-sized " + "tensors with a graph that has dynamic-sized tensors."); return kTfLiteError; } } diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 7ef736d01b9b8a91ca2deb107bfbe92399405233..651a97e9dc84350569514528ae5635ec040d607f 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -349,6 +349,10 @@ class Interpreter { return context_.allow_fp32_relax_to_fp16; } + // Owning handle to a TfLiteDelegate instance. + using TfLiteDelegatePtr = + std::unique_ptr; + // Allow a delegate to look at the graph and modify the graph to handle // parts of the graph themselves. After this is called, the graph may // contain new nodes that replace 1 more nodes. @@ -574,19 +578,11 @@ class Interpreter { TfLiteExternalContextType type, TfLiteExternalContext* ctx); - using TfLiteDelegatePtr = - std::unique_ptr; - // Variant of the public ModifyGraphWithDelegate method that additionally // Assumes ownership of the provided delegate. // WARNING: This is an experimental API and subject to change. 
- template - TfLiteStatus ModifyGraphWithDelegate(std::unique_ptr typed_delegate, + TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegatePtr delegate, bool allow_dynamic_tensors = false) { - TfLiteDelegatePtr delegate(typed_delegate.release(), - [](TfLiteDelegate* delegate) { - delete static_cast(delegate); - }); // Note that we retain ownership of the delegate even if graph modification // fails, as delegate use will be in an indeterminate state at that point. owned_delegates_.push_back(std::move(delegate)); @@ -676,6 +672,7 @@ class Interpreter { // List of delegates that have been installed and are owned by this // interpreter instance. Useful if client delegate ownership is burdensome. // WARNING: This is an experimental API and subject to change. + // TODO(b/116667551): Use TfLiteExternalContext for storing state. std::vector owned_delegates_; std::unique_ptr memory_planner_; diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index cdede430e29be7b18939f55a8bb06b66f1a3ea33..6c71d5a8d7bb3e275379637b151ab8f998b04f41 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -30,7 +30,11 @@ class InterpreterTest : public ::testing::Test { template static TfLiteStatus ModifyGraphWithDelegate( Interpreter* interpreter, std::unique_ptr delegate) { - return interpreter->ModifyGraphWithDelegate(std::move(delegate)); + Interpreter::TfLiteDelegatePtr tflite_delegate( + delegate.release(), [](TfLiteDelegate* delegate) { + delete reinterpret_cast(delegate); + }); + return interpreter->ModifyGraphWithDelegate(std::move(tflite_delegate)); } protected: diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 098ba7e7731d833678fbd5eab9cce3f022570f23..cab8d5277f2d3f539e7a69f15ebda20821b19a3b 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -11,6 +11,10 @@ load("//tensorflow/java:build_defs.bzl", 
"JAVACOPTS") load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary") load("//tensorflow/contrib/lite/java:aar_with_jni.bzl", "aar_with_jni") +JAVA_SRCS = glob([ + "src/main/java/org/tensorflow/lite/*.java", +]) + # Building tensorflow-lite.aar including 4 variants of .so # To build an aar for release, run below command: # bazel build --cxxopt='--std=c++11' -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ @@ -20,28 +24,38 @@ aar_with_jni( android_library = ":tensorflowlite", ) +# EXPERIMENTAL: AAR target that supports TensorFlow op execution with TFLite. +aar_with_jni( + name = "tensorflow-lite-flex", + android_library = ":tensorflowlite_flex", +) + android_library( name = "tensorflowlite", - srcs = glob( - [ - "src/main/java/org/tensorflow/lite/*.java", - ], - ), + srcs = JAVA_SRCS, + manifest = "AndroidManifest.xml", + visibility = ["//visibility:public"], + deps = [ + ":tensorflowlite_native", + "@org_checkerframework_qual", + ], +) + +# EXPERIMENTAL: Android target that supports TensorFlow op execution with TFLite. +android_library( + name = "tensorflowlite_flex", + srcs = JAVA_SRCS, manifest = "AndroidManifest.xml", visibility = ["//visibility:public"], deps = [ - ":tflite_runtime", + ":tensorflowlite_native_flex", "@org_checkerframework_qual", ], ) android_library( name = "tensorflowlite_java", - srcs = glob( - [ - "src/main/java/org/tensorflow/lite/*.java", - ], - ), + srcs = JAVA_SRCS, visibility = ["//visibility:public"], deps = [ "@org_checkerframework_qual", @@ -50,16 +64,23 @@ android_library( java_library( name = "tensorflowlitelib", - srcs = glob( - [ - "src/main/java/org/tensorflow/lite/*.java", - ], - ), + srcs = JAVA_SRCS, javacopts = JAVACOPTS, visibility = ["//visibility:public"], deps = [ ":libtensorflowlite_jni.so", - "//tensorflow/contrib/lite/java/src/main/native", + "@org_checkerframework_qual", + ], +) + +# EXPERIMENTAL: Java target that supports TensorFlow op execution with TFLite. 
+java_library( + name = "tensorflowlitelib_flex", + srcs = JAVA_SRCS, + javacopts = JAVACOPTS, + visibility = ["//visibility:public"], + deps = [ + ":libtensorflowlite_flex_jni.so", "@org_checkerframework_qual", ], ) @@ -72,7 +93,6 @@ java_test( tags = ["no_oss"], test_class = "org.tensorflow.lite.TensorFlowLiteTest", deps = [ - ":libtensorflowlite_jni.so", ":tensorflowlitelib", "@com_google_truth", "@junit", @@ -87,7 +107,6 @@ java_test( tags = ["no_oss"], test_class = "org.tensorflow.lite.DataTypeTest", deps = [ - ":libtensorflowlite_jni.so", ":tensorflowlitelib", "@com_google_truth", "@junit", @@ -110,7 +129,6 @@ java_test( tags = ["no_oss"], test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest", deps = [ - ":libtensorflowlite_jni.so", ":tensorflowlitelib", "@com_google_truth", "@junit", @@ -125,19 +143,37 @@ java_test( data = [ "src/testdata/add.bin", "src/testdata/mobilenet.tflite.bin", + "//tensorflow/contrib/lite:testdata/multi_add_flex.bin", ], javacopts = JAVACOPTS, tags = ["no_oss"], test_class = "org.tensorflow.lite.InterpreterTest", visibility = ["//visibility:private"], deps = [ - ":libtensorflowlite_jni.so", ":tensorflowlitelib", "@com_google_truth", "@junit", ], ) +java_test( + name = "InterpreterFlexTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/InterpreterFlexTest.java"], + data = [ + "//tensorflow/contrib/lite:testdata/multi_add_flex.bin", + ], + javacopts = JAVACOPTS, + tags = ["no_oss"], + test_class = "org.tensorflow.lite.InterpreterFlexTest", + visibility = ["//visibility:private"], + deps = [ + ":tensorflowlitelib_flex", + "@com_google_truth", + "@junit", + ], +) + java_test( name = "TensorTest", size = "small", @@ -164,14 +200,30 @@ filegroup( ) cc_library( - name = "tflite_runtime", + name = "tensorflowlite_native", srcs = ["libtensorflowlite_jni.so"], visibility = ["//visibility:public"], ) +cc_library( + name = "tensorflowlite_native_flex", + srcs = ["libtensorflowlite_flex_jni.so"], + visibility = 
["//visibility:public"], +) + tflite_jni_binary( name = "libtensorflowlite_jni.so", deps = [ "//tensorflow/contrib/lite/java/src/main/native", ], ) + +# EXPERIMENTAL: Native target that supports TensorFlow op execution with TFLite. +tflite_jni_binary( + name = "libtensorflowlite_flex_jni.so", + deps = [ + "//tensorflow/contrib/lite/delegates/flex:delegate", + "//tensorflow/contrib/lite/java/src/main/native", + "//tensorflow/contrib/lite/java/src/main/native:init_tensorflow", + ], +) diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl index 9d2aead266f897e8b08520d06ea60654927029e9..360d622b1bcf5cf379987ceefc43c74b1b6ce5fb 100644 --- a/tensorflow/contrib/lite/java/aar_with_jni.bzl +++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl @@ -30,7 +30,10 @@ EOF # In some platforms we don't have an Android SDK/NDK and this target # can't be built. We need to prevent the build system from trying to # use the target in that case. - tags = ["manual"], + tags = [ + "manual", + "no_cuda_on_cpu_tap", + ], ) native.genrule( diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD index bb0be04ca2a659dfb5e0c73bcf0485fb425d5ed0..ea9b9ed4b66a601981f4c402f7f8a4f6749e07fd 100644 --- a/tensorflow/contrib/lite/java/ovic/BUILD +++ b/tensorflow/contrib/lite/java/ovic/BUILD @@ -9,6 +9,7 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") +# Build targets for OVIC classification. 
java_test( name = "OvicClassifierTest", size = "medium", @@ -45,8 +46,9 @@ android_library( name = "ovicbenchmarkerlib", srcs = [ "src/main/java/org/tensorflow/ovic/OvicBenchmarker.java", + "src/main/java/org/tensorflow/ovic/OvicClassificationResult.java", "src/main/java/org/tensorflow/ovic/OvicClassifier.java", - "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", + "src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java", ], manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml", tags = ["no_oss"], @@ -60,8 +62,8 @@ android_library( java_library( name = "ovicbenchmarkerlib_java", srcs = [ + "src/main/java/org/tensorflow/ovic/OvicClassificationResult.java", "src/main/java/org/tensorflow/ovic/OvicClassifier.java", - "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", ], javacopts = JAVACOPTS, tags = ["no_oss"], @@ -73,3 +75,58 @@ java_library( "@org_checkerframework_qual", ], ) + +# Build targets for OVIC detection. +java_test( + name = "OvicDetectorTest", + size = "medium", + srcs = ["src/test/java/org/tensorflow/ovic/OvicDetectorTest.java"], + data = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:coco_labels.txt", + "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + "@tflite_mobilenet_ssd_quant//:detect.tflite", + ], + javacopts = JAVACOPTS, + tags = ["no_oss"], + test_class = "org.tensorflow.ovic.OvicDetectorTest", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/lite/java/ovic:ovicdetectionbenchmarkerlib_java", + "@com_google_truth", + "@junit", + ], +) + +android_library( + name = "ovicdetectionbenchmarkerlib", + srcs = [ + "src/main/java/org/tensorflow/ovic/BoundingBox.java", + "src/main/java/org/tensorflow/ovic/OvicBenchmarker.java", + "src/main/java/org/tensorflow/ovic/OvicDetectionResult.java", + "src/main/java/org/tensorflow/ovic/OvicDetector.java", + "src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java", + ], + manifest = 
"//tensorflow/contrib/lite/java:AndroidManifest.xml", + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) + +java_library( + name = "ovicdetectionbenchmarkerlib_java", + srcs = [ + "src/main/java/org/tensorflow/ovic/BoundingBox.java", + "src/main/java/org/tensorflow/ovic/OvicDetectionResult.java", + "src/main/java/org/tensorflow/ovic/OvicDetector.java", + ], + javacopts = JAVACOPTS, + deps = [ + "//tensorflow/contrib/lite/java:libtensorflowlite_jni.so", + "//tensorflow/contrib/lite/java:tensorflowlite_java", + "//tensorflow/contrib/lite/java/src/main/native", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD index 058240aada52fa533aeb81997f5ad7bbc11f8b42..f567358ea33966ea8fdb422749662e22111c5fcc 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD +++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD @@ -10,8 +10,10 @@ android_binary( ], aapt_version = "aapt", assets = [ - "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + "//tensorflow/contrib/lite/java/ovic/src/testdata:coco_labels.txt", "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + "@tflite_mobilenet_ssd_quant//:detect.tflite", ], assets_dir = "", custom_package = "ovic.demo.app", @@ -25,6 +27,7 @@ android_binary( deps = [ "//tensorflow/contrib/lite/java:tensorflowlite", "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib", + "//tensorflow/contrib/lite/java/ovic:ovicdetectionbenchmarkerlib", "@androidsdk//com.android.support:support-v13-25.2.0", "@androidsdk//com.android.support:support-v4-25.2.0", ], diff --git 
a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java index 4adf94aeb6a431adc3c574bd4d3f0418538884f2..48c29ecebeed42ac9a2e0bc801cab1fb1f9201e8 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java +++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java @@ -35,19 +35,18 @@ import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.text.DecimalFormat; import org.tensorflow.ovic.OvicBenchmarker; -import org.tensorflow.ovic.OvicSingleImageResult; - +import org.tensorflow.ovic.OvicClassifierBenchmarker; +import org.tensorflow.ovic.OvicDetectorBenchmarker; /** Class that benchmark image classifier models. */ public class OvicBenchmarkerActivity extends Activity { /** Tag for the {@link Log}. */ private static final String TAG = "OvicBenchmarkerActivity"; - /** Name of the label file stored in Assets. */ - private static final String LABEL_PATH = "labels.txt"; - - private static final String TEST_IMAGE_PATH = "test_image_224.jpg"; - private static final String MODEL_PATH = "float_model.lite"; + /** Name of the task-dependent data files stored in Assets. */ + private static String labelPath = null; + private static String testImagePath = null; + private static String modelPath = null; /** * Each bottom press will launch a benchmarking experiment. The experiment stops when either the * total native latency reaches WALL_TIME or the number of iterations reaches MAX_ITERATIONS, @@ -66,8 +65,6 @@ public class OvicBenchmarkerActivity extends Activity { private MappedByteBuffer model = null; private InputStream labelInputStream = null; private OvicBenchmarker benchmarker; - /** Inference result of each iteration. 
*/ - OvicSingleImageResult iterResult = null; private TextView textView = null; // private Button startButton = null; @@ -83,21 +80,31 @@ public class OvicBenchmarkerActivity extends Activity { } private Bitmap loadTestBitmap() throws IOException { - InputStream imageStream = getAssets().open(TEST_IMAGE_PATH); + InputStream imageStream = getAssets().open(testImagePath); return BitmapFactory.decodeStream(imageStream); } - public void initializeTest() throws IOException { + public void initializeTest(boolean benchmarkClassification) throws IOException { Log.i(TAG, "Initializing benchmarker."); - benchmarker = new OvicBenchmarker(WALL_TIME); + if (benchmarkClassification) { + benchmarker = new OvicClassifierBenchmarker(WALL_TIME); + labelPath = "labels.txt"; + testImagePath = "test_image_224.jpg"; + modelPath = "quantized_model.lite"; + } else { // Benchmarking detection. + benchmarker = new OvicDetectorBenchmarker(WALL_TIME); + labelPath = "coco_labels.txt"; + testImagePath = "test_image_224.jpg"; + modelPath = "detect.tflite"; + } AssetManager am = getAssets(); - AssetFileDescriptor fileDescriptor = am.openFd(MODEL_PATH); + AssetFileDescriptor fileDescriptor = am.openFd(modelPath); FileInputStream modelInputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = modelInputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); long declaredLength = fileDescriptor.getDeclaredLength(); model = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); - labelInputStream = am.open(LABEL_PATH); + labelInputStream = am.open(labelPath); } public Boolean doTestIteration() throws IOException, InterruptedException { @@ -117,24 +124,44 @@ public class OvicBenchmarkerActivity extends Activity { Log.i(TAG, "Going to do test iter."); // Start testing. 
Bitmap testImageBitmap = loadTestBitmap(); - iterResult = benchmarker.doTestIteration(testImageBitmap); - testImageBitmap.recycle(); - if (iterResult == null) { + try { + if (!benchmarker.processBitmap(testImageBitmap)) { + throw new RuntimeException("Failed to run test."); + } + } catch (Exception e) { + e.printStackTrace(); + throw e; + } finally { + testImageBitmap.recycle(); + } + String iterResultString = benchmarker.getLastResultString(); + if (iterResultString == null) { throw new RuntimeException("Inference failed to produce a result."); } - Log.i(TAG, iterResult.toString()); + Log.i(TAG, iterResultString); return true; } - public void startPressed(View view) throws IOException { - Log.i(TAG, "Start pressed"); + public void detectPressed(View view) throws IOException { + benchmarkSession(false); + } + public void classifyPressed(View view) throws IOException { + benchmarkSession(true); + } + + private void benchmarkSession(boolean benchmarkClassification) throws IOException { try { - initializeTest(); + initializeTest(benchmarkClassification); } catch (IOException e) { Log.e(TAG, "Can't initialize benchmarker.", e); throw e; } String displayText = ""; + if (benchmarkClassification) { + displayText = "Classification benchmark: "; + } else { + displayText = "Detection benchmark: "; + } try { setProcessorAffinity(BIG_CORE_MASK); } catch (IOException e) { @@ -144,7 +171,6 @@ public class OvicBenchmarkerActivity extends Activity { Log.i(TAG, "Successfully initialized benchmarker."); int testIter = 0; Boolean iterSuccess = false; - double totalLatency = 0.0f; while (testIter < MAX_ITERATIONS) { try { iterSuccess = doTestIteration(); @@ -153,23 +179,22 @@ public class OvicBenchmarkerActivity extends Activity { throw e; } catch (InterruptedException e) { Log.e(TAG, "Interrupted at iteration " + testIter); + displayText += e.getMessage() + "\n"; } if (!iterSuccess) { break; } testIter++; - totalLatency += (double) iterResult.latency; } - ; Log.i(TAG, "Benchmarking 
finished"); if (textView != null) { if (testIter > 0) { textView.setText( displayText - + MODEL_PATH + + modelPath + ": Average latency=" - + df2.format(totalLatency / testIter) + + df2.format(benchmarker.getTotalRunTime() / testIter) + "ms after " + testIter + " runs."); diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml index e9d83bae543ae62ba8749c4c91b36b20bf09a176..1bce60ff7def2b0df9c93a4106a9aafff0009a2f 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml +++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml @@ -30,14 +30,14 @@ android:layout_height="wrap_content" android:text="@string/initial_status_msg" android:id="@+id/textView" - android:layout_above="@+id/button_start" + android:layout_above="@+id/button_clf_start" android:layout_alignParentTop="true"/>